# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch XLNet model.
"""


import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F

from .activations import gelu_new, swish
from .configuration_xlnet import XLNetConfig
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "XLNetConfig"
_TOKENIZER_FOR_DOC = "XLNetTokenizer"

XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlnet-base-cased",
    "xlnet-large-cased",
    # See all XLNet models at https://huggingface.co/models?filter=xlnet
]


def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):
    """ A map of modules from TF to PyTorch.
        We use a map to keep the PyTorch model as close to the original TF model as possible.
    """

    tf_to_pt_map = {}

    if hasattr(model, "transformer"):
        if hasattr(model, "lm_loss"):
            # We will also load the output bias
            tf_to_pt_map["model/lm_loss/bias"] = model.lm_loss.bias
        if hasattr(model, "sequence_summary") and "model/sequnece_summary/summary/kernel" in tf_weights:
            # We will also load the sequence summary
            tf_to_pt_map["model/sequnece_summary/summary/kernel"] = model.sequence_summary.summary.weight
            tf_to_pt_map["model/sequnece_summary/summary/bias"] = model.sequence_summary.summary.bias
        if (
            hasattr(model, "logits_proj")
            and config.finetuning_task is not None
            and "model/regression_{}/logit/kernel".format(config.finetuning_task) in tf_weights
        ):
            tf_to_pt_map["model/regression_{}/logit/kernel".format(config.finetuning_task)] = model.logits_proj.weight
            tf_to_pt_map["model/regression_{}/logit/bias".format(config.finetuning_task)] = model.logits_proj.bias

        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings and output
    tf_to_pt_map.update(
        {
            "model/transformer/word_embedding/lookup_table": model.word_embedding.weight,
            "model/transformer/mask_emb/mask_emb": model.mask_emb,
        }
    )

    # Transformer blocks
    for i, b in enumerate(model.layer):
        layer_str = "model/transformer/layer_%d/" % i
        tf_to_pt_map.update(
            {
                layer_str + "rel_attn/LayerNorm/gamma": b.rel_attn.layer_norm.weight,
                layer_str + "rel_attn/LayerNorm/beta": b.rel_attn.layer_norm.bias,
                layer_str + "rel_attn/o/kernel": b.rel_attn.o,
                layer_str + "rel_attn/q/kernel": b.rel_attn.q,
                layer_str + "rel_attn/k/kernel": b.rel_attn.k,
                layer_str + "rel_attn/r/kernel": b.rel_attn.r,
                layer_str + "rel_attn/v/kernel": b.rel_attn.v,
                layer_str + "ff/LayerNorm/gamma": b.ff.layer_norm.weight,
                layer_str + "ff/LayerNorm/beta": b.ff.layer_norm.bias,
                layer_str + "ff/layer_1/kernel": b.ff.layer_1.weight,
                layer_str + "ff/layer_1/bias": b.ff.layer_1.bias,
                layer_str + "ff/layer_2/kernel": b.ff.layer_2.weight,
                layer_str + "ff/layer_2/bias": b.ff.layer_2.bias,
            }
        )

    # Relative positioning biases
    if config.untie_r:
        r_r_list = []
        r_w_list = []
        r_s_list = []
        seg_embed_list = []
        for b in model.layer:
            r_r_list.append(b.rel_attn.r_r_bias)
            r_w_list.append(b.rel_attn.r_w_bias)
            r_s_list.append(b.rel_attn.r_s_bias)
            seg_embed_list.append(b.rel_attn.seg_embed)
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
        r_s_list = [model.r_s_bias]
        seg_embed_list = [model.seg_embed]
    tf_to_pt_map.update(
        {
            "model/transformer/r_r_bias": r_r_list,
            "model/transformer/r_w_bias": r_w_list,
            "model/transformer/r_s_bias": r_s_list,
            "model/transformer/seg_embed": seg_embed_list,
        }
    )
    return tf_to_pt_map


def load_tf_weights_in_xlnet(model, config, tf_path):
    """ Load tf checkpoints in a pytorch model
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        tf_weights[name] = array

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_xlnet_to_pytorch_map(model, config, tf_weights)

    for name, pointer in tf_to_pt_map.items():
        logger.info("Importing {}".format(name))
        if name not in tf_weights:
            logger.info("{} not in tf pre-trained weights, skipping".format(name))
            continue
        array = tf_weights[name]
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if "kernel" in name and ("ff" in name or "summary" in name or "logit" in name):
            logger.info("Transposing")
            array = np.transpose(array)
        if isinstance(pointer, list):
            # Here we will split the TF weights
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                try:
                    assert p_i.shape == arr_i.shape
                except AssertionError as e:
                    e.args += (p_i.shape, arr_i.shape)
                    raise
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        tf_weights.pop(name, None)
        tf_weights.pop(name + "/Adam", None)
        tf_weights.pop(name + "/Adam_1", None)

    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
    return model
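

# A minimal usage sketch (added for illustration; the checkpoint paths are hypothetical) of the
# TF-to-PyTorch conversion implemented above:
#
#     config = XLNetConfig.from_json_file("/path/to/xlnet_config.json")
#     model = XLNetLMHeadModel(config)
#     model = load_tf_weights_in_xlnet(model, config, "/path/to/xlnet_model.ckpt")
#
# `tf_path` is the TensorFlow checkpoint prefix produced by the original XLNet training code.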


ACT2FN = {"gelu": gelu_new, "relu": torch.nn.functional.relu, "swish": swish}


XLNetLayerNorm = nn.LayerNorm


class XLNetRelativeAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        if config.d_model % config.n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.d_model, config.n_head)
            )

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)

        self.q = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.FloatTensor(config.d_model, self.n_head, self.d_head))

        self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))

        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout)

    def prune_heads(self, heads):
        raise NotImplementedError

    @staticmethod
    def rel_shift(x, klen=-1):
        """perform relative shift to form the relative attention score."""
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        # x = x[:, 0:klen, :, :]
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        x_size = x.shape

        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        # Note: the tensor-slice form was faster in my testing than torch.index_select.
        # However, tracing doesn't like the nature of the slice, and if klen changes
        # during the run then it'll fail, whereas index_select will be fine.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
        # x = x[:, :, :, :klen]

        return x
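
    # Note (added for clarity): `rel_shift_bnij` is the standard Transformer-XL "relative shift".
    # The position scores `bd` are computed against every relative position embedding at once, so
    # their last axis is indexed by the embedding position rather than by the key position; the
    # reshape / drop-first-row / reshape sequence realigns that axis per query position, and the
    # trailing index_select truncates it to the key length `klen` so it can be added to `ac`.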

    def rel_attn_core(
        self,
        q_head,
        k_head_h,
        v_head_h,
        k_head_r,
        seg_mat=None,
        attn_mask=None,
        head_mask=None,
        output_attentions=False,
    ):
        """Core relative positional attention operations."""

        # content based attention score
        ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)

        # position based attention score
        bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift_bnij(bd, klen=ac.shape[3])

        # segment based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)

        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
            if attn_mask.dtype == torch.float16:
                attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
            else:
                attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)

        # attention probability
        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.dropout(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)

        # attention output
        attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)

        if output_attentions:
            return attn_vec, torch.einsum("bnij->ijbn", attn_prob)

        return attn_vec

    def post_attention(self, h, attn_vec, residual=True):
        """Post-attention processing."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)

        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)

        return output

    def forward(
        self,
        h,
        g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        head_mask=None,
        output_attentions=False,
    ):
        if g is not None:
            # Two-stream attention with relative positional encoding.
            # content based attention score
            if mems is not None and mems.dim() > 1:
                cat = torch.cat([mems, h], dim=0)
            else:
                cat = h

            # content-based key head
            k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)

            # content-based value head
            v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)

            # position-based key head
            k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)

            # h-stream
            # content-stream query head
            q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)

            # core attention ops
            attn_vec_h = self.rel_attn_core(
                q_head_h,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat=seg_mat,
                attn_mask=attn_mask_h,
                head_mask=head_mask,
                output_attentions=output_attentions,
            )

            if output_attentions:
                attn_vec_h, attn_prob_h = attn_vec_h

            # post processing
            output_h = self.post_attention(h, attn_vec_h)

            # g-stream
            # query-stream query head
            q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)

            # core attention ops
            if target_mapping is not None:
                q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat=seg_mat,
                    attn_mask=attn_mask_g,
                    head_mask=head_mask,
                    output_attentions=output_attentions,
                )

                if output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g

                attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
            else:
                attn_vec_g = self.rel_attn_core(
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat=seg_mat,
                    attn_mask=attn_mask_g,
                    head_mask=head_mask,
                    output_attentions=output_attentions,
                )

                if output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g

            # post processing
            output_g = self.post_attention(g, attn_vec_g)

            if output_attentions:
                attn_prob = attn_prob_h, attn_prob_g

        else:
            # Multi-head attention with relative positional encoding
            if mems is not None and mems.dim() > 1:
                cat = torch.cat([mems, h], dim=0)
            else:
                cat = h

            # content heads
            q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)

            # positional heads
            k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)

            # core attention ops
            attn_vec = self.rel_attn_core(
                q_head_h,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat=seg_mat,
                attn_mask=attn_mask_h,
                head_mask=head_mask,
                output_attentions=output_attentions,
            )

            if output_attentions:
                attn_vec, attn_prob = attn_vec

            # post processing
            output_h = self.post_attention(h, attn_vec)
            output_g = None

        outputs = (output_h, output_g)
        if output_attentions:
            outputs = outputs + (attn_prob,)
        return outputs


class XLNetFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        if isinstance(config.ff_activation, str):
            self.activation_function = ACT2FN[config.ff_activation]
        else:
            self.activation_function = config.ff_activation

    def forward(self, inp):
        output = inp
        output = self.layer_1(output)
        output = self.activation_function(output)
        output = self.dropout(output)
        output = self.layer_2(output)
        output = self.dropout(output)
        output = self.layer_norm(output + inp)
        return output


class XLNetLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.rel_attn = XLNetRelativeAttention(config)
        self.ff = XLNetFeedForward(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self,
        output_h,
        output_g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems=None,
        target_mapping=None,
        head_mask=None,
        output_attentions=False,
    ):
        outputs = self.rel_attn(
            output_h,
            output_g,
            attn_mask_h,
            attn_mask_g,
            r,
            seg_mat,
            mems=mems,
            target_mapping=target_mapping,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        output_h, output_g = outputs[:2]

        if output_g is not None:
            output_g = self.ff(output_g)
        output_h = self.ff(output_h)

        outputs = (output_h, output_g) + outputs[2:]  # Add the attentions back in if they are there
        return outputs


class XLNetPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = XLNetConfig
    load_tf_weights = load_tf_weights_in_xlnet
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, XLNetLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, XLNetRelativeAttention):
            for param in [
                module.q,
                module.k,
                module.v,
                module.o,
                module.r,
                module.r_r_bias,
                module.r_s_bias,
                module.r_w_bias,
                module.seg_embed,
            ]:
                param.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, XLNetModel):
            module.mask_emb.data.normal_(mean=0.0, std=self.config.initializer_range)


@dataclass
class XLNetModelOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetModel`.

    Args:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.

            ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
            ``num_predict`` corresponds to ``sequence_length``.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: torch.FloatTensor
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XLNetLMHeadModelOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetLMHeadModel`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss (for next-token prediction).
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).

            ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then
            ``num_predict`` corresponds to ``sequence_length``.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XLNetForSequenceClassificationOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetForSequenceClassification`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XLNetForTokenClassificationOutput(ModelOutput):
    """
    Output type of :class:`~transformers.XLNetForTokenClassification`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
            Classification scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XLNetForMultipleChoiceOutput(ModelOutput):
    """
    Base class for outputs of multiple choice models.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            `num_choices` is the second dimension of the input tensors (see `input_ids` above).

            Classification scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XLNetForQuestionAnsweringSimpleOutput(ModelOutput):
    """
    Base class for outputs of question answering models.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-start scores (before SoftMax).
        end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-end scores (before SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_logits: torch.FloatTensor = None
    end_logits: torch.FloatTensor = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class XLNetForQuestionAnsweringOutput(ModelOutput):
    """
    Base class for outputs of question answering models using a :obj:`SquadHead`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned if both :obj:`start_positions` and :obj:`end_positions` are provided):
            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
        start_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the top config.start_n_top start token possibilities (beam-search).
        start_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Indices for the top config.start_n_top start token possibilities (beam-search).
        end_top_log_probs (``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        end_top_index (``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
        cls_logits (``torch.FloatTensor`` of shape ``(batch_size,)``, `optional`, returned if ``start_positions`` or ``end_positions`` is not provided):
            Log probabilities for the ``is_impossible`` label of the answers.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states.
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
            should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    start_top_log_probs: Optional[torch.FloatTensor] = None
    start_top_index: Optional[torch.LongTensor] = None
    end_top_log_probs: Optional[torch.FloatTensor] = None
    end_top_index: Optional[torch.LongTensor] = None
    cls_logits: Optional[torch.FloatTensor] = None
    mems: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


XLNET_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

XLNET_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.XLNetTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
            given to this model should not be passed as input ids as they have already been computed.
            `use_cache` has to be set to `True` to make use of `mems`.
        perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
            If ``perm_mask[k, i, j] = 0``, i attends to j in batch k;
            if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k.
            If None, each token attends to all the others (full bidirectional attention).
            Only used during pretraining (to define factorization order) or for sequential decoding (generation).
        target_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the output tokens to use.
            If ``target_mapping[k, i, j] = 1``, the i-th prediction in batch k is on the j-th token.
            Only used during pretraining for partial prediction or for sequential decoding (generation).
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token. The classifier token should be represented by a ``2``.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        input_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
            Kept for compatibility with the original code base.
            You can only use one of `input_mask` and `attention_mask`.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        use_cache (:obj:`bool`):
            If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""


@add_start_docstrings(
    "The bare XLNet Model transformer outputting raw hidden-states without any specific head on top.",
    XLNET_START_DOCSTRING,
)
class XLNetModel(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer

        self.word_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, config.d_model))
        self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
        self.dropout = nn.Dropout(config.dropout)

        self.init_weights()

    def get_input_embeddings(self):
        return self.word_embedding

    def set_input_embeddings(self, new_embeddings):
        self.word_embedding = new_embeddings

    def _prune_heads(self, heads_to_prune):
        raise NotImplementedError

    def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: Sequence length
            mlen: Mask length

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
          qlen   [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        attn_mask = torch.ones([qlen, qlen])
        mask_up = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(self.device)
        return ret
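
    # Worked example (added for clarity): with qlen=3, mlen=2 and same_length=False, create_mask
    # returns
    #     [[0., 0., 0., 1., 1.],
    #      [0., 0., 0., 0., 1.],
    #      [0., 0., 0., 0., 0.]]
    # i.e. each query position can attend to all memory positions and to itself and earlier
    # positions of the current segment, but not to later ones (1.0 marks a masked entry).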

    def cache_mem(self, curr_out, prev_mem):
        # cache hidden states into memory.
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

        if self.mem_len is None or self.mem_len == 0:
            # If `use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
            # and returns all of the past and current hidden states.
            cutoff = 0
        else:
            # If `use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
            # states. This is the preferred setting for training and long-form generation.
            cutoff = -self.mem_len
        if prev_mem is None:
            # if there is no previous memory, the cache is just the (possibly truncated) current output
            new_mem = curr_out[cutoff:]
        else:
            new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]

        return new_mem.detach()
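
    # Worked example (added for clarity): with mem_len=4, a previous memory of length 4 and a new
    # segment output of length 3, the updated memory is the last 4 rows of the length-7
    # concatenation [prev_mem; curr_out], i.e. the most recent 4 hidden states.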

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        # create relative positional encoding.
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == "bi":
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(self.device)
        return pos_emb
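
    # Note (added for clarity): this is the standard sinusoidal embedding evaluated at relative
    # positions. With inv_freq[k] = 10000^(-2k / d_model), each position `pos` in
    # fwd_pos_seq = [klen, klen - 1, ..., -qlen + 1] is mapped to
    # [sin(pos * inv_freq), cos(pos * inv_freq)], giving one d_model-dim vector per relative offset.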

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
        # but we want a unified interface in the library with the batch size on the first dimension
        # so we move here the first dimension (batch) to the end
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.shape[0], input_ids.shape[1]
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
        input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
        attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
        perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
        target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None

        mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen

        dtype_float = self.dtype
        device = self.device

        # Attention mask
        # causal attention mask
        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError("Unsupported attention type: {}".format(self.attn_type))

        # data mask: input mask & perm mask
        assert input_mask is None or attention_mask is None, (
            "You can only use one of input_mask (uses 1 for padding) "
            "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
        )
        if input_mask is None and attention_mask is not None:
            input_mask = 1.0 - attention_mask
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            # all mems can be attended to
            if mlen > 0:
                mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
                data_mask = torch.cat([mems_mask, data_mask], dim=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = (attn_mask > 0).to(dtype_float)

        if attn_mask is not None:
            non_tgt_mask = -torch.eye(qlen).to(attn_mask)
            if mlen > 0:
                non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
            non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
        else:
            non_tgt_mask = None

        # Word embeddings and prepare h & g hidden states
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        output_h = self.dropout(word_emb_k)
        if target_mapping is not None:
            word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
            # else:  # We removed the inp_q input which was same as target mapping
            #     inp_q_ext = inp_q[:, :, None]
            #     word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
            output_g = self.dropout(word_emb_q)
        else:
            output_g = None

        # Segment embedding
        if token_type_ids is not None:
            # Convert `token_type_ids` to one-hot `seg_mat`
            if mlen > 0:
                mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
                cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
            else:
                cat_ids = token_type_ids

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
            seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
        else:
            seg_mat = None

        # Positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layer)

        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layer):
            if use_cache:
                # cache new mems
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(
                output_h,
                output_g,
                attn_mask_h=non_tgt_mask,
                attn_mask_g=attn_mask,
                r=pos_emb,
                seg_mat=seg_mat,
                mems=mems[i],
                target_mapping=target_mapping,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )
            output_h, output_g = outputs[:2]
            if output_attentions:
                attentions.append(outputs[2])

        # Add last hidden state
        if output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        output = self.dropout(output_g if output_g is not None else output_h)

        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        output = output.permute(1, 0, 2).contiguous()

        # TODO Teven: fix this test to only use use_cache.
        if not use_cache:
            new_mems = None

        if output_hidden_states:
            if output_g is not None:
                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)

        if output_attentions:
            if target_mapping is not None:
                # when target_mapping is provided, there is a 2-tuple of attentions per layer
                attentions = tuple(
                    tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions
                )
            else:
                attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)

        if not return_dict:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

        return XLNetModelOutput(
            last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
        )


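# A minimal usage sketch of the bare model (added; mirrors the docstring examples in this file):
#
#     from transformers import XLNetTokenizer
#     tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
#     model = XLNetModel.from_pretrained("xlnet-base-cased", return_dict=True)
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs)
#     last_hidden = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)

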
@add_start_docstrings(
    """XLNet Model with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    XLNET_START_DOCSTRING,
)
class XLNetLMHeadModel(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.attn_type = config.attn_type
        self.same_length = config.same_length

        self.transformer = XLNetModel(config)
        self.lm_loss = nn.Linear(config.d_model, config.vocab_size, bias=True)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_loss

    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # Add dummy token at the end (no attention on this one)

        effective_batch_size = input_ids.shape[0]
        dummy_token = torch.zeros((effective_batch_size, 1), dtype=torch.long, device=input_ids.device)

        # At every pass, the attention values for the new token and the two last generated tokens
        # are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
        # offset = 1; offset = 2 seems to have slightly better computation.
        offset = 2

        if past:
            input_ids = torch.cat([input_ids[:, -offset:], dummy_token], dim=1)
        else:
            input_ids = torch.cat([input_ids, dummy_token], dim=1)

        # Build permutation mask so that previous tokens don't see last token
        sequence_length = input_ids.shape[1]
        perm_mask = torch.zeros(
            (effective_batch_size, sequence_length, sequence_length), dtype=torch.float, device=input_ids.device
        )
        perm_mask[:, :, -1] = 1.0

        # We'll only predict the last token
        target_mapping = torch.zeros(
            (effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
        )
        target_mapping[0, 0, -1] = 1.0

        inputs = {
            "input_ids": input_ids,
            "perm_mask": perm_mask,
            "target_mapping": target_mapping,
            "use_cache": kwargs["use_cache"],
        }

        # if past is defined in model kwargs then use it for faster decoding
        if past:
            inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)

        return inputs
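
    # Illustration (added for clarity): at each generation step the current prefix, e.g. ids
    # [A, B, C], is fed as [A, B, C, <dummy>] together with a perm_mask that hides the final
    # position from every token and a target_mapping that selects only that final position, so the
    # returned logits are the distribution over the next token.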

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=XLNetLMHeadModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_predict)`, `optional`, defaults to :obj:`None`):
            Labels for masked language modeling.
            `num_predict` corresponds to `target_mapping.shape[1]`. If `target_mapping` is `None`, then `num_predict`
            corresponds to `sequence_length`.
            The labels should correspond to the masked input words that should be predicted and depend on
            `target_mapping`. Note that in order to perform standard auto-regressive language modeling, a `<mask>`
            token has to be added to the `input_ids` (see the `prepare_inputs_for_generation` function and examples below).
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``.
            All labels set to ``-100`` are ignored, the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``.

        Return:

        Examples::

            from transformers import XLNetTokenizer, XLNetLMHeadModel
            import torch

            tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
            model = XLNetLMHeadModel.from_pretrained('xlnet-large-cased', return_dict=True)

            # We show how to setup inputs to predict a next token using a bi-directional context.
            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0)  # We will predict the masked token
            perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
            target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
            target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)

            outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
            next_token_logits = outputs[0]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

            # In the same way, XLNetLMHeadModel can be trained with standard auto-regressive language modeling.
            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=False)).unsqueeze(0)  # We will predict the masked token
            labels = torch.tensor(tokenizer.encode("cute", add_special_tokens=False)).unsqueeze(0)
            assert labels.shape[0] == 1, 'only one word will be predicted'
            perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
            perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token as is done in standard auto-regressive lm training
            target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
            target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)

            outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping, labels=labels)
            loss = outputs.loss
            next_token_logits = outputs.logits  # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_loss(transformer_outputs[0])

        loss = None
        if labels is not None:
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetLMHeadModelOutput(
            loss=loss,
            logits=logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """XLNet Model with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    XLNET_START_DOCSTRING,
)
class XLNetForSequenceClassification(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.logits_proj = nn.Linear(config.d_model, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForSequenceClassificationOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss);
            if ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
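
        Example (a minimal, illustrative sketch added for clarity; the checkpoint, sentence and label value are
        placeholders rather than anything prescribed by this repository)::

            from transformers import XLNetTokenizer, XLNetForSequenceClassification
            import torch

            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
            model = XLNetForSequenceClassification.from_pretrained('xlnet-base-cased', return_dict=True)

            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            labels = torch.tensor([1])  # hypothetical class index, shape (batch_size,)

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss      # Cross-Entropy here since config.num_labels > 1
            logits = outputs.logits  # shape (batch_size, config.num_labels)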
1459"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        output = transformer_outputs[0]

        output = self.sequence_summary(output)
        logits = self.logits_proj(output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForSequenceClassificationOutput(
            loss=loss,
            logits=logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """XLNet Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    XLNET_START_DOCSTRING,
)
class XLNetForTokenClassification(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForTokenClassificationOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
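
        Example (an illustrative sketch added for clarity; the checkpoint and the all-zero labels are placeholders,
        not values used by this repository)::

            from transformers import XLNetTokenizer, XLNetForTokenClassification
            import torch

            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
            model = XLNetForTokenClassification.from_pretrained('xlnet-base-cased', return_dict=True)

            input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            labels = torch.zeros(input_ids.shape, dtype=torch.long)  # one hypothetical label id per token

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss      # token classification loss
            logits = outputs.logits  # shape (batch_size, sequence_length, config.num_labels)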
1550"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForTokenClassificationOutput(
            loss=loss,
            logits=logits,
            mems=outputs.mems,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """XLNet Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
    XLNET_START_DOCSTRING,
)
class XLNetForMultipleChoice(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.logits_proj = nn.Linear(config.d_model, 1)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForMultipleChoiceOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        input_mask=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second
            dimension of the input tensors (see `input_ids` above).
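
        Example (an illustrative sketch added for clarity; the checkpoint, prompt, choices and label are
        placeholders, not values prescribed by this repository)::

            from transformers import XLNetTokenizer, XLNetForMultipleChoice
            import torch

            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
            model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased', return_dict=True)

            prompt = "The dog wanted to go outside, so"
            choices = ["it scratched at the door.", "the moon is made of cheese."]

            # Encode each (prompt, choice) pair and stack to shape (batch_size, num_choices, sequence_length)
            encodings = [tokenizer.encode(prompt, choice, add_special_tokens=True) for choice in choices]
            max_len = max(len(e) for e in encodings)
            input_ids = torch.tensor([e + [tokenizer.pad_token_id] * (max_len - len(e)) for e in encodings]).unsqueeze(0)
            labels = torch.tensor([0])  # hypothetical index of the correct choice

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            logits = outputs.logits  # shape (batch_size, num_choices)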
1645"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_input_mask = input_mask.view(-1, input_mask.size(-1)) if input_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        transformer_outputs = self.transformer(
            flat_input_ids,
            token_type_ids=flat_token_type_ids,
            input_mask=flat_input_mask,
            attention_mask=flat_attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        output = transformer_outputs[0]

        output = self.sequence_summary(output)
        logits = self.logits_proj(output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels.view(-1))

        if not return_dict:
            output = (reshaped_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return XLNetForMultipleChoiceOutput(
            loss=loss,
            logits=reshaped_logits,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLNET_START_DOCSTRING,
)
class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = XLNetModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="xlnet-base-cased",
        output_type=XLNetForQuestionAnsweringSimpleOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the start of the labelled span, used to compute the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the end of the labelled span, used to compute the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
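
        Example (an illustrative sketch added for clarity; the checkpoint, question/context and span positions are
        placeholders, not values prescribed by this repository)::

            from transformers import XLNetTokenizer, XLNetForQuestionAnsweringSimple
            import torch

            tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
            model = XLNetForQuestionAnsweringSimple.from_pretrained('xlnet-base-cased', return_dict=True)

            question, context = "Who owns the dog?", "My neighbour owns a very cute dog."
            input_ids = torch.tensor(tokenizer.encode(question, context, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            start_positions = torch.tensor([5])  # hypothetical start index of the answer span
            end_positions = torch.tensor([6])    # hypothetical end index of the answer span

            outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
            loss = outputs.loss
            start_logits = outputs.start_logits  # shape (batch_size, sequence_length)
            end_logits = outputs.end_logits      # shape (batch_size, sequence_length)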
1749"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, batch splitting adds an extra dimension; remove it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return XLNetForQuestionAnsweringSimpleOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            mems=outputs.mems,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XLNET_START_DOCSTRING,
)
class XLNetForQuestionAnswering(XLNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.start_n_top = config.start_n_top
        self.end_n_top = config.end_n_top

        self.transformer = XLNetModel(config)
        self.start_logits = PoolerStartLogits(config)
        self.end_logits = PoolerEndLogits(config)
        self.answer_class = PoolerAnswerClass(config)

        self.init_weights()

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @replace_return_docstrings(output_type=XLNetForQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        is_impossible=None,
        cls_index=None,
        p_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the start of the labelled span, used to compute the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the end of the labelled span, used to compute the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        is_impossible (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for whether a question has an answer or no answer (SQuAD 2.0).
        cls_index (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
            Labels for the position (index) of the classification token, used as input for computing the plausibility of the answer.
        p_mask (``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
            Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
            1.0 means the token should be masked; 0.0 means the token is not masked.

        Returns:

        Example::

            >>> from transformers import XLNetTokenizer, XLNetForQuestionAnswering
            >>> import torch

            >>> tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
            >>> model = XLNetForQuestionAnswering.from_pretrained('xlnet-base-cased', return_dict=True)

            >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
            >>> start_positions = torch.tensor([1])
            >>> end_positions = torch.tensor([3])
            >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)

            >>> loss = outputs.loss
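
            >>> # An illustrative extension (not part of the original example): SQuAD-2.0-style answerability
            >>> # supervision via `cls_index` and `is_impossible`; the values below are placeholders.
            >>> cls_index = torch.tensor([input_ids.shape[1] - 1])  # XLNet places the CLS token at the end of the sequence
            >>> is_impossible = torch.tensor([0.0])  # 0.0 = the question is answerable
            >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions,
            ...                 cls_index=cls_index, is_impossible=is_impossible)
            >>> loss = outputs.loss  # now also includes 0.5 * answerability (BCE) loss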
1881"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = self.training or (use_cache if use_cache is not None else self.config.use_cache)

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)

        outputs = transformer_outputs[1:]  # Keep mems, hidden states and attentions if they are present

        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, let's remove the dimension added by batch splitting
            for x in (start_positions, end_positions, cls_index, is_impossible):
                if x is not None and x.dim() > 1:
                    x.squeeze_(-1)

            # during training, compute the end logits based on the ground truth of the start position
            end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

            loss_fct = CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

            if cls_index is not None and is_impossible is not None:
                # Predict answerability from the representation of CLS and START
                cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                loss_fct_cls = nn.BCEWithLogitsLoss()
                cls_loss = loss_fct_cls(cls_logits, is_impossible)

                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                total_loss += cls_loss * 0.5

            if not return_dict:
                return (total_loss,) + transformer_outputs[1:]
            else:
                return XLNetForQuestionAnsweringOutput(
                    loss=total_loss,
                    mems=transformer_outputs.mems,
                    hidden_states=transformer_outputs.hidden_states,
                    attentions=transformer_outputs.attentions,
                )

        else:
            # during inference, compute the end logits based on beam search
            bsz, slen, hsz = hidden_states.size()
            start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

            start_top_log_probs, start_top_index = torch.topk(
                start_log_probs, self.start_n_top, dim=-1
            )  # shape (bsz, start_n_top)
            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
            start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

            hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
                start_states
            )  # shape (bsz, slen, start_n_top, hsz)
            p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
            end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
            end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

            end_top_log_probs, end_top_index = torch.topk(
                end_log_probs, self.end_n_top, dim=1
            )  # shape (bsz, end_n_top, start_n_top)
            end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
            end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

            start_states = torch.einsum(
                "blh,bl->bh", hidden_states, start_log_probs
            )  # get the representation of START as weighted sum of hidden states
            cls_logits = self.answer_class(
                hidden_states, start_states=start_states, cls_index=cls_index
            )  # Shape (batch size,): one single `cls_logits` for each sample

            if not return_dict:
                outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
                return outputs + transformer_outputs[1:]
            else:
                return XLNetForQuestionAnsweringOutput(
                    start_top_log_probs=start_top_log_probs,
                    start_top_index=start_top_index,
                    end_top_log_probs=end_top_log_probs,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits,
                    mems=transformer_outputs.mems,
                    hidden_states=transformer_outputs.hidden_states,
                    attentions=transformer_outputs.attentions,
                )
