CSS-LM
1595 строк · 74.4 Кб
1# coding=utf-8
2# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16""" TF 2.0 XLNet model.
17"""
18
19
20import logging
21
22import numpy as np
23import tensorflow as tf
24
25from .configuration_xlnet import XLNetConfig
26from .file_utils import (
27MULTIPLE_CHOICE_DUMMY_INPUTS,
28add_code_sample_docstrings,
29add_start_docstrings,
30add_start_docstrings_to_callable,
31)
32from .modeling_tf_utils import (
33TFCausalLanguageModelingLoss,
34TFMultipleChoiceLoss,
35TFPreTrainedModel,
36TFQuestionAnsweringLoss,
37TFSequenceClassificationLoss,
38TFSequenceSummary,
39TFSharedEmbeddings,
40TFTokenClassificationLoss,
41cast_bool_to_primitive,
42get_initializer,
43keras_serializable,
44shape_list,
45)
46from .tokenization_utils import BatchEncoding
47
48
49logger = logging.getLogger(__name__)
50
51_TOKENIZER_FOR_DOC = "XLNetTokenizer"
52
53TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
54"xlnet-base-cased",
55"xlnet-large-cased",
56# See all XLNet models at https://huggingface.co/models?filter=xlnet
57]
58
59
60def gelu(x):
61""" Implementation of the gelu activation function.
62XLNet is using OpenAI GPT's gelu
63Also see https://arxiv.org/abs/1606.08415
64"""
65cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
66return x * cdf
67
68
69def swish(x):
70return x * tf.sigmoid(x)
71
72
73ACT2FN = {
74"gelu": tf.keras.layers.Activation(gelu),
75"relu": tf.keras.activations.relu,
76"swish": tf.keras.layers.Activation(swish),
77}
78
79
80class TFXLNetRelativeAttention(tf.keras.layers.Layer):
81def __init__(self, config, **kwargs):
82super().__init__(**kwargs)
83
84if config.d_model % config.n_head != 0:
85raise ValueError(
86"The hidden size (%d) is not a multiple of the number of attention "
87"heads (%d)" % (config.d_model, config.n_head)
88)
89
90self.n_head = config.n_head
91self.d_head = config.d_head
92self.d_model = config.d_model
93self.scale = 1 / (config.d_head ** 0.5)
94self.initializer_range = config.initializer_range
95
96self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
97self.dropout = tf.keras.layers.Dropout(config.dropout)
98
99def build(self, input_shape):
100initializer = get_initializer(self.initializer_range)
101self.q = self.add_weight(
102shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
103)
104self.k = self.add_weight(
105shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k"
106)
107self.v = self.add_weight(
108shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v"
109)
110self.o = self.add_weight(
111shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o"
112)
113self.r = self.add_weight(
114shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r"
115)
116self.r_r_bias = self.add_weight(
117shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
118)
119self.r_s_bias = self.add_weight(
120shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias"
121)
122self.r_w_bias = self.add_weight(
123shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
124)
125self.seg_embed = self.add_weight(
126shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
127)
128super().build(input_shape)
129
130def prune_heads(self, heads):
131raise NotImplementedError
132
133def rel_shift(self, x, klen=-1):
134"""perform relative shift to form the relative attention score."""
135x_size = shape_list(x)
136
137x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3]))
138x = x[1:, ...]
139x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3]))
140x = x[:, 0:klen, :, :]
141# x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
142
143return x
144
145def rel_attn_core(self, inputs, training=False):
146"""Core relative positional attention operations."""
147
148q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs
149
150# content based attention score
151ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
152
153# position based attention score
154bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r)
155bd = self.rel_shift(bd, klen=shape_list(ac)[1])
156
157# segment based attention score
158if seg_mat is None:
159ef = 0
160else:
161ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
162ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef)
163
164# merge attention scores and perform masking
165attn_score = (ac + bd + ef) * self.scale
166if attn_mask is not None:
167# attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
168if attn_mask.dtype == tf.float16:
169attn_score = attn_score - 65500 * attn_mask
170else:
171attn_score = attn_score - 1e30 * attn_mask
172
173# attention probability
174attn_prob = tf.nn.softmax(attn_score, axis=1)
175
176attn_prob = self.dropout(attn_prob, training=training)
177
178# Mask heads if we want to
179if head_mask is not None:
180attn_prob = attn_prob * head_mask
181
182# attention output
183attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
184
185if cast_bool_to_primitive(output_attentions) is True:
186return attn_vec, attn_prob
187
188return attn_vec
189
190def post_attention(self, inputs, residual=True, training=False):
191"""Post-attention processing."""
192# post-attention projection (back to `d_model`)
193h, attn_vec = inputs
194
195attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
196
197attn_out = self.dropout(attn_out, training=training)
198
199if residual:
200attn_out = attn_out + h
201output = self.layer_norm(attn_out)
202
203return output
204
205def call(self, inputs, training=False):
206(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs
207
208if g is not None:
209# Two-stream attention with relative positional encoding.
210# content based attention score
211if mems is not None and len(shape_list(mems)) > 1:
212cat = tf.concat([mems, h], axis=0)
213else:
214cat = h
215
216# content-based key head
217k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
218
219# content-based value head
220v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
221
222# position-based key head
223k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
224
225# h-stream
226# content-stream query head
227q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
228
229# core attention ops
230attn_vec_h = self.rel_attn_core(
231[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
232training=training,
233)
234
235if cast_bool_to_primitive(output_attentions) is True:
236attn_vec_h, attn_prob_h = attn_vec_h
237
238# post processing
239output_h = self.post_attention([h, attn_vec_h], training=training)
240
241# g-stream
242# query-stream query head
243q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
244
245# core attention ops
246if target_mapping is not None:
247q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
248attn_vec_g = self.rel_attn_core(
249[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
250training=training,
251)
252
253if cast_bool_to_primitive(output_attentions) is True:
254attn_vec_g, attn_prob_g = attn_vec_g
255
256attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
257else:
258attn_vec_g = self.rel_attn_core(
259[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
260training=training,
261)
262
263if cast_bool_to_primitive(output_attentions) is True:
264attn_vec_g, attn_prob_g = attn_vec_g
265
266# post processing
267output_g = self.post_attention([g, attn_vec_g], training=training)
268
269if cast_bool_to_primitive(output_attentions) is True:
270attn_prob = attn_prob_h, attn_prob_g
271
272else:
273# Multi-head attention with relative positional encoding
274if mems is not None and len(shape_list(mems)) > 1:
275cat = tf.concat([mems, h], axis=0)
276else:
277cat = h
278
279# content heads
280q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
281k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
282v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
283
284# positional heads
285k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
286
287# core attention ops
288attn_vec = self.rel_attn_core(
289[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
290training=training,
291)
292
293if cast_bool_to_primitive(output_attentions) is True:
294attn_vec, attn_prob = attn_vec
295
296# post processing
297output_h = self.post_attention([h, attn_vec], training=training)
298output_g = None
299
300outputs = (output_h, output_g)
301if cast_bool_to_primitive(output_attentions) is True:
302outputs = outputs + (attn_prob,)
303return outputs
304
305
306class TFXLNetFeedForward(tf.keras.layers.Layer):
307def __init__(self, config, **kwargs):
308super().__init__(**kwargs)
309self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
310self.layer_1 = tf.keras.layers.Dense(
311config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1"
312)
313self.layer_2 = tf.keras.layers.Dense(
314config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
315)
316self.dropout = tf.keras.layers.Dropout(config.dropout)
317if isinstance(config.ff_activation, str):
318self.activation_function = ACT2FN[config.ff_activation]
319else:
320self.activation_function = config.ff_activation
321
322def call(self, inp, training=False):
323output = inp
324output = self.layer_1(output)
325output = self.activation_function(output)
326output = self.dropout(output, training=training)
327output = self.layer_2(output)
328output = self.dropout(output, training=training)
329output = self.layer_norm(output + inp)
330return output
331
332
333class TFXLNetLayer(tf.keras.layers.Layer):
334def __init__(self, config, **kwargs):
335super().__init__(**kwargs)
336self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn")
337self.ff = TFXLNetFeedForward(config, name="ff")
338self.dropout = tf.keras.layers.Dropout(config.dropout)
339
340def call(self, inputs, training=False):
341outputs = self.rel_attn(inputs, training=training)
342output_h, output_g = outputs[:2]
343
344if output_g is not None:
345output_g = self.ff(output_g, training=training)
346output_h = self.ff(output_h, training=training)
347
348outputs = (output_h, output_g) + outputs[2:] # Add again attentions if there are there
349return outputs
350
351
352class TFXLNetLMHead(tf.keras.layers.Layer):
353def __init__(self, config, input_embeddings, **kwargs):
354super().__init__(**kwargs)
355self.vocab_size = config.vocab_size
356# The output weights are the same as the input embeddings, but there is
357# an output-only bias for each token.
358self.input_embeddings = input_embeddings
359
360def build(self, input_shape):
361self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
362super().build(input_shape)
363
364def call(self, hidden_states):
365hidden_states = self.input_embeddings(hidden_states, mode="linear")
366hidden_states = hidden_states + self.bias
367return hidden_states
368
369
370@keras_serializable
371class TFXLNetMainLayer(tf.keras.layers.Layer):
372config_class = XLNetConfig
373
374def __init__(self, config, **kwargs):
375super().__init__(**kwargs)
376self.output_hidden_states = config.output_hidden_states
377self.output_attentions = config.output_attentions
378
379self.mem_len = config.mem_len
380self.reuse_len = config.reuse_len
381self.d_model = config.d_model
382self.same_length = config.same_length
383self.attn_type = config.attn_type
384self.bi_data = config.bi_data
385self.clamp_len = config.clamp_len
386self.n_layer = config.n_layer
387self.use_bfloat16 = config.use_bfloat16
388self.initializer_range = config.initializer_range
389
390self.word_embedding = TFSharedEmbeddings(
391config.vocab_size, config.d_model, initializer_range=config.initializer_range, name="word_embedding"
392)
393self.layer = [TFXLNetLayer(config, name="layer_._{}".format(i)) for i in range(config.n_layer)]
394self.dropout = tf.keras.layers.Dropout(config.dropout)
395
396def get_input_embeddings(self):
397return self.word_embedding
398
399def set_input_embeddings(self, value):
400self.word_embedding.weight = value
401self.word_embedding.vocab_size = value.shape[0]
402
403def build(self, input_shape):
404initializer = get_initializer(self.initializer_range)
405self.mask_emb = self.add_weight(
406shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb"
407)
408
409def _resize_token_embeddings(self, new_num_tokens):
410raise NotImplementedError
411
412def _prune_heads(self, heads_to_prune):
413raise NotImplementedError
414
415def create_mask(self, qlen, mlen, dtype=tf.float32):
416"""
417Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
418
419Args:
420qlen: TODO Lysandre didn't fill
421mlen: TODO Lysandre didn't fill
422
423::
424
425same_length=False: same_length=True:
426<mlen > < qlen > <mlen > < qlen >
427^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
428[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
429qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
430[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
431v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
432
433"""
434attn_mask = tf.ones([qlen, qlen], dtype=dtype)
435mask_u = tf.matrix_band_part(attn_mask, 0, -1)
436mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
437attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
438ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
439if self.same_length:
440mask_l = tf.matrix_band_part(attn_mask, -1, 0)
441ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
442return ret
443
444def cache_mem(self, curr_out, prev_mem):
445"""cache hidden states into memory."""
446if self.reuse_len is not None and self.reuse_len > 0:
447curr_out = curr_out[: self.reuse_len]
448
449if prev_mem is None:
450new_mem = curr_out[-self.mem_len :]
451else:
452new_mem = tf.concat([prev_mem, curr_out], 0)[-self.mem_len :]
453
454return tf.stop_gradient(new_mem)
455
456@staticmethod
457def positional_embedding(pos_seq, inv_freq, bsz=None):
458sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
459pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
460pos_emb = pos_emb[:, None, :]
461
462if bsz is not None:
463pos_emb = tf.tile(pos_emb, [1, bsz, 1])
464
465return pos_emb
466
467def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
468"""create relative positional encoding."""
469freq_seq = tf.range(0, self.d_model, 2.0)
470if dtype is not None and dtype != tf.float32:
471freq_seq = tf.cast(freq_seq, dtype=dtype)
472inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
473
474if self.attn_type == "bi":
475# beg, end = klen - 1, -qlen
476beg, end = klen, -qlen
477elif self.attn_type == "uni":
478# beg, end = klen - 1, -1
479beg, end = klen, -1
480else:
481raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))
482
483if self.bi_data:
484fwd_pos_seq = tf.range(beg, end, -1.0)
485bwd_pos_seq = tf.range(-beg, -end, 1.0)
486
487if dtype is not None and dtype != tf.float32:
488fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
489bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
490
491if self.clamp_len > 0:
492fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
493bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
494
495if bsz is not None:
496# With bi_data, the batch size should be divisible by 2.
497assert bsz % 2 == 0
498fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
499bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
500else:
501fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
502bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
503
504pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
505else:
506fwd_pos_seq = tf.range(beg, end, -1.0)
507if dtype is not None and dtype != tf.float32:
508fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
509if self.clamp_len > 0:
510fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
511pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
512
513return pos_emb
514
515def call(
516self,
517inputs,
518attention_mask=None,
519mems=None,
520perm_mask=None,
521target_mapping=None,
522token_type_ids=None,
523input_mask=None,
524head_mask=None,
525inputs_embeds=None,
526use_cache=True,
527output_attentions=None,
528output_hidden_states=None,
529training=False,
530):
531if isinstance(inputs, (tuple, list)):
532input_ids = inputs[0]
533attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
534mems = inputs[2] if len(inputs) > 2 else mems
535perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
536target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
537token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
538input_mask = inputs[6] if len(inputs) > 6 else input_mask
539head_mask = inputs[7] if len(inputs) > 7 else head_mask
540inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
541use_cache = inputs[9] if len(inputs) > 9 else use_cache
542output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
543output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
544assert len(inputs) <= 12, "Too many inputs."
545elif isinstance(inputs, (dict, BatchEncoding)):
546input_ids = inputs.get("input_ids")
547attention_mask = inputs.get("attention_mask", attention_mask)
548mems = inputs.get("mems", mems)
549perm_mask = inputs.get("perm_mask", perm_mask)
550target_mapping = inputs.get("target_mapping", target_mapping)
551token_type_ids = inputs.get("token_type_ids", token_type_ids)
552input_mask = inputs.get("input_mask", input_mask)
553head_mask = inputs.get("head_mask", head_mask)
554inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
555use_cache = inputs.get("use_cache", use_cache)
556output_attentions = inputs.get("output_attentions", output_attentions)
557output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
558assert len(inputs) <= 12, "Too many inputs."
559else:
560input_ids = inputs
561
562output_attentions = output_attentions if output_attentions is not None else self.output_attentions
563output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
564
565# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
566# but we want a unified interface in the library with the batch size on the first dimension
567# so we move here the first dimension (batch) to the end
568
569if input_ids is not None and inputs_embeds is not None:
570raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
571elif input_ids is not None:
572input_ids = tf.transpose(input_ids, perm=(1, 0))
573qlen, bsz = shape_list(input_ids)[:2]
574elif inputs_embeds is not None:
575inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
576qlen, bsz = shape_list(inputs_embeds)[:2]
577else:
578raise ValueError("You have to specify either input_ids or inputs_embeds")
579
580token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
581input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
582attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
583perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
584target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
585
586mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
587klen = mlen + qlen
588
589dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
590
591# Attention mask
592# causal attention mask
593if self.attn_type == "uni":
594attn_mask = self.create_mask(qlen, mlen)
595attn_mask = attn_mask[:, :, None, None]
596elif self.attn_type == "bi":
597attn_mask = None
598else:
599raise ValueError("Unsupported attention type: {}".format(self.attn_type))
600
601# data mask: input mask & perm mask
602assert input_mask is None or attention_mask is None, (
603"You can only use one of input_mask (uses 1 for padding) "
604"or attention_mask (uses 0 for padding, added for compatbility with BERT). Please choose one."
605)
606if input_mask is None and attention_mask is not None:
607input_mask = 1.0 - tf.cast(attention_mask, dtype=dtype_float)
608if input_mask is not None and perm_mask is not None:
609data_mask = input_mask[None] + perm_mask
610elif input_mask is not None and perm_mask is None:
611data_mask = input_mask[None]
612elif input_mask is None and perm_mask is not None:
613data_mask = perm_mask
614else:
615data_mask = None
616
617if data_mask is not None:
618# all mems can be attended to
619if mlen > 0:
620mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
621data_mask = tf.concat([mems_mask, data_mask], axis=1)
622if attn_mask is None:
623attn_mask = data_mask[:, :, :, None]
624else:
625attn_mask += data_mask[:, :, :, None]
626
627if attn_mask is not None:
628attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)
629
630if attn_mask is not None:
631non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
632if mlen > 0:
633non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
634non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
635else:
636non_tgt_mask = None
637
638# Word embeddings and prepare h & g hidden states
639if inputs_embeds is not None:
640word_emb_k = inputs_embeds
641else:
642word_emb_k = self.word_embedding(input_ids)
643output_h = self.dropout(word_emb_k, training=training)
644if target_mapping is not None:
645word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
646# else: # We removed the inp_q input which was same as target mapping
647# inp_q_ext = inp_q[:, :, None]
648# word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
649output_g = self.dropout(word_emb_q, training=training)
650else:
651output_g = None
652
653# Segment embedding
654if token_type_ids is not None:
655# Convert `token_type_ids` to one-hot `seg_mat`
656if mlen > 0:
657mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
658cat_ids = tf.concat([mem_pad, token_type_ids], 0)
659else:
660cat_ids = token_type_ids
661
662# `1` indicates not in the same segment [qlen x klen x bsz]
663seg_mat = tf.cast(tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])), tf.int32)
664seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
665else:
666seg_mat = None
667
668# Positional encoding
669pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
670pos_emb = self.dropout(pos_emb, training=training)
671
672# Prepare head mask if needed
673# 1.0 in head_mask indicate we keep the head
674# attention_probs has shape bsz x n_heads x N x N
675# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
676# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
677if head_mask is not None:
678raise NotImplementedError
679else:
680head_mask = [None] * self.n_layer
681
682new_mems = ()
683if mems is None:
684mems = [None] * len(self.layer)
685
686attentions = []
687hidden_states = []
688for i, layer_module in enumerate(self.layer):
689# cache new mems
690if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
691new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
692if cast_bool_to_primitive(output_hidden_states) is True:
693hidden_states.append((output_h, output_g) if output_g is not None else output_h)
694
695outputs = layer_module(
696[
697output_h,
698output_g,
699non_tgt_mask,
700attn_mask,
701pos_emb,
702seg_mat,
703mems[i],
704target_mapping,
705head_mask[i],
706output_attentions,
707],
708training=training,
709)
710output_h, output_g = outputs[:2]
711if cast_bool_to_primitive(output_attentions) is True:
712attentions.append(outputs[2])
713
714# Add last hidden state
715if cast_bool_to_primitive(output_hidden_states) is True:
716hidden_states.append((output_h, output_g) if output_g is not None else output_h)
717
718output = self.dropout(output_g if output_g is not None else output_h, training=training)
719
720# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
721outputs = (tf.transpose(output, perm=(1, 0, 2)),)
722
723if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
724outputs = outputs + (new_mems,)
725
726if cast_bool_to_primitive(output_hidden_states) is True:
727if output_g is not None:
728hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
729else:
730hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
731outputs = outputs + (hidden_states,)
732if cast_bool_to_primitive(output_attentions) is True:
733attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
734outputs = outputs + (attentions,)
735
736return outputs # outputs, (new_mems), (hidden_states), (attentions)
737
738
739class TFXLNetPreTrainedModel(TFPreTrainedModel):
740""" An abstract class to handle weights initialization and
741a simple interface for downloading and loading pretrained models.
742"""
743
744config_class = XLNetConfig
745base_model_prefix = "transformer"
746
747
748XLNET_START_DOCSTRING = r"""
749
750.. note::
751
752TF 2.0 models accepts two formats as inputs:
753
754- having all inputs as keyword arguments (like PyTorch models), or
755- having all inputs as a list, tuple or dict in the first positional arguments.
756
757This second option is useful when using :obj:`tf.keras.Model.fit()` method which currently requires having
758all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
759
760If you choose this second option, there are three possibilities you can use to gather all the input Tensors
761in the first positional argument :
762
763- a single Tensor with input_ids only and nothing else: :obj:`model(inputs_ids)`
764- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
765:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
766- a dictionary with one or several input Tensors associated to the input names given in the docstring:
767:obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
768
769Parameters:
770config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
771Initializing with a config file does not load the weights associated with the model, only the configuration.
772Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
773"""
774
775XLNET_INPUTS_DOCSTRING = r"""
776Args:
777input_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
778Indices of input sequence tokens in the vocabulary.
779
780Indices can be obtained using :class:`transformers.XLNetTokenizer`.
781See :func:`transformers.PreTrainedTokenizer.encode` and
782:func:`transformers.PreTrainedTokenizer.__call__` for details.
783
784`What are input IDs? <../glossary.html#input-ids>`__
785attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
786Mask to avoid performing attention on padding token indices.
787Mask values selected in ``[0, 1]``:
788``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
789
790`What are attention masks? <../glossary.html#attention-mask>`__
791mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
792Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
793(see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
794given to this model should not be passed as input ids as they have already been computed.
795perm_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
796Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
797If ``perm_mask[k, i, j] = 0``, i attend to j in batch k;
798if ``perm_mask[k, i, j] = 1``, i does not attend to j in batch k.
799If None, each token attends to all the others (full bidirectional attention).
800Only used during pretraining (to define factorization order) or for sequential decoding (generation).
801target_mapping (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, num_predict, sequence_length)`, `optional`, defaults to :obj:`None`):
802Mask to indicate the output tokens to use.
803If ``target_mapping[k, i, j] = 1``, the i-th predict in batch k is on the j-th token.
804Only used during pretraining for partial prediction or for sequential decoding (generation).
805token_type_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
806Segment token indices to indicate first and second portions of the inputs.
807Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
808corresponds to a `sentence B` token
809
810`What are token type IDs? <../glossary.html#token-type-ids>`_
811input_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
812Mask to avoid performing attention on padding token indices.
813Negative of `attention_mask`, i.e. with 0 for real tokens and 1 for padding.
814Kept for compatibility with the original code base.
815You can only uses one of `input_mask` and `attention_mask`
816Mask values selected in ``[0, 1]``:
817``1`` for tokens that are MASKED, ``0`` for tokens that are NOT MASKED.
818head_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
819Mask to nullify selected heads of the self-attention modules.
820Mask values selected in ``[0, 1]``:
821:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
822inputs_embeds (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
823Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
824This is useful if you want more control over how to convert `input_ids` indices into associated vectors
825than the model's internal embedding lookup matrix.
826use_cache (:obj:`bool`):
827If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
828output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
829If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
830"""
831
832
833@add_start_docstrings(
834"The bare XLNet Model transformer outputing raw hidden-states without any specific head on top.",
835XLNET_START_DOCSTRING,
836)
837class TFXLNetModel(TFXLNetPreTrainedModel):
838def __init__(self, config, *inputs, **kwargs):
839super().__init__(config, *inputs, **kwargs)
840self.transformer = TFXLNetMainLayer(config, name="transformer")
841
842@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
843@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
844def call(self, inputs, **kwargs):
845r"""
846Return:
847:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
848last_hidden_state (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
849Sequence of hidden-states at the last layer of the model.
850mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
851Contains pre-computed hidden-states (key and values in the attention blocks).
852Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
853should not be passed as input ids as they have already been computed.
854hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
855Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for the output of the embeddings + one for the output of each layer)
856of shape :obj:`(batch_size, sequence_length, hidden_size)`.
857
858Hidden-states of the model at the output of each layer plus the initial embedding outputs.
859attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
860Tuple of :obj:`tf.Tensor` or :obj:`Numpy array` (one for each layer) of shape
861:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
862
863Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
864heads.
865"""
866outputs = self.transformer(inputs, **kwargs)
867return outputs
868
869
870@add_start_docstrings(
871"""XLNet Model with a language modeling head on top
872(linear layer with weights tied to the input embeddings). """,
873XLNET_START_DOCSTRING,
874)
875class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
876def __init__(self, config, *inputs, **kwargs):
877super().__init__(config, *inputs, **kwargs)
878self.transformer = TFXLNetMainLayer(config, name="transformer")
879self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss")
880
881def get_output_embeddings(self):
882return self.lm_loss.input_embeddings
883
884def prepare_inputs_for_generation(self, inputs, past, **kwargs):
885# Add dummy token at the end (no attention on this one)
886
887# At every pass, the attention values for the new token and the two last generated tokens
888# are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
889# offset = 1; offset = 2 seems to have slightly better computation.
890offset = 2
891
892effective_batch_size = inputs.shape[0]
893dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
894
895if past:
896inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
897else:
898inputs = tf.concat([inputs, dummy_token], axis=1)
899
900# Build permutation mask so that previous tokens don't see last token
901sequence_length = inputs.shape[1]
902perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
903perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
904perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
905
906# We'll only predict the last token
907target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
908target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
909target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
910
911inputs = {
912"inputs": inputs,
913"perm_mask": perm_mask,
914"target_mapping": target_mapping,
915"use_cache": kwargs["use_cache"],
916}
917
918# if past is defined in model kwargs then use it for faster decoding
919if past:
920inputs["mems"] = tuple(layer_past[:-offset, :, :] for layer_past in past)
921
922return inputs
923
924@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
925def call(
926self,
927inputs,
928attention_mask=None,
929mems=None,
930perm_mask=None,
931target_mapping=None,
932token_type_ids=None,
933input_mask=None,
934head_mask=None,
935inputs_embeds=None,
936use_cache=True,
937output_attentions=None,
938output_hidden_states=None,
939labels=None,
940training=False,
941):
942r"""
943labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
944Labels for computing the cross entropy classification loss.
945Indices should be in ``[0, ..., config.vocab_size - 1]``.
946
947Return:
948:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
949prediction_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
950Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
951mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
952Contains pre-computed hidden-states (key and values in the attention blocks).
953Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
954should not be passed as input ids as they have already been computed.
955hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
956tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
957of shape :obj:`(batch_size, sequence_length, hidden_size)`.
958
959Hidden-states of the model at the output of each layer plus the initial embedding outputs.
960attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
961tuple of :obj:`tf.Tensor` (one for each layer) of shape
962:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
963
964Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
965heads.
966
967Examples::
968
969import tensorflow as tf
970import numpy as np
971from transformers import XLNetTokenizer, TFXLNetLMHeadModel
972
973tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
974model = TFXLNetLMHeadModel.from_pretrained('xlnet-large-cased')
975
976# We show how to setup inputs to predict a next token using a bi-directional context.
977input_ids = tf.constant(tokenizer.encode("Hello, my dog is very <mask>", add_special_tokens=True))[None, :] # We will predict the masked token
978
979perm_mask = np.zeros((1, input_ids.shape[1], input_ids.shape[1]))
980perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
981
982target_mapping = np.zeros((1, 1, input_ids.shape[1])) # Shape [1, 1, seq_length] => let's predict one token
983target_mapping[0, 0, -1] = 1.0 # Our first (and only) prediction will be the last token of the sequence (the masked token)
984
985outputs = model(input_ids, perm_mask=tf.constant(perm_mask, dtype=tf.float32), target_mapping=tf.constant(target_mapping, dtype=tf.float32))
986
987next_token_logits = outputs[0] # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
988
989"""
990if isinstance(inputs, (tuple, list)):
991labels = inputs[12] if len(inputs) > 12 else labels
992if len(inputs) > 12:
993inputs = inputs[:12]
994elif isinstance(inputs, (dict, BatchEncoding)):
995labels = inputs.pop("labels", labels)
996
997transformer_outputs = self.transformer(
998inputs,
999attention_mask=None,
1000mems=None,
1001perm_mask=None,
1002target_mapping=None,
1003token_type_ids=None,
1004input_mask=None,
1005head_mask=None,
1006inputs_embeds=None,
1007use_cache=True,
1008output_attentions=None,
1009output_hidden_states=None,
1010training=training,
1011)
1012hidden_state = transformer_outputs[0]
1013logits = self.lm_loss(hidden_state, training=training)
1014
1015outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
1016
1017if labels is not None:
1018# shift labels to the left and cut last logit token
1019logits = logits[:, :-1]
1020labels = labels[:, 1:]
1021loss = self.compute_loss(labels, logits)
1022outputs = (loss,) + outputs
1023
1024return outputs # return logits, (mems), (hidden states), (attentions)
1025
1026
1027@add_start_docstrings(
1028"""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
1029the pooled output) e.g. for GLUE tasks. """,
1030XLNET_START_DOCSTRING,
1031)
1032class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel, TFSequenceClassificationLoss):
1033def __init__(self, config, *inputs, **kwargs):
1034super().__init__(config, *inputs, **kwargs)
1035self.num_labels = config.num_labels
1036
1037self.transformer = TFXLNetMainLayer(config, name="transformer")
1038self.sequence_summary = TFSequenceSummary(
1039config, initializer_range=config.initializer_range, name="sequence_summary"
1040)
1041self.logits_proj = tf.keras.layers.Dense(
1042config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
1043)
1044
1045@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
1046@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
1047def call(
1048self,
1049inputs=None,
1050attention_mask=None,
1051mems=None,
1052perm_mask=None,
1053target_mapping=None,
1054token_type_ids=None,
1055input_mask=None,
1056head_mask=None,
1057inputs_embeds=None,
1058use_cache=True,
1059output_attentions=None,
1060output_hidden_states=None,
1061labels=None,
1062training=False,
1063):
1064r"""
1065labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1066Labels for computing the sequence classification/regression loss.
1067Indices should be in ``[0, ..., config.num_labels - 1]``.
1068If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
1069If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
1070
1071Return:
1072:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
1073logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:(batch_size, config.num_labels)`):
1074Classification (or regression if config.num_labels==1) scores (before SoftMax).
1075mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
1076Contains pre-computed hidden-states (key and values in the attention blocks).
1077Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
1078should not be passed as input ids as they have already been computed.
1079hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1080tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
1081of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1082
1083Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1084attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1085tuple of :obj:`tf.Tensor` (one for each layer) of shape
1086:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
1087
1088Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1089heads.
1090"""
1091if isinstance(inputs, (tuple, list)):
1092labels = inputs[12] if len(inputs) > 12 else labels
1093if len(inputs) > 12:
1094inputs = inputs[:12]
1095elif isinstance(inputs, (dict, BatchEncoding)):
1096labels = inputs.pop("labels", labels)
1097
1098transformer_outputs = self.transformer(
1099inputs,
1100attention_mask=attention_mask,
1101mems=mems,
1102perm_mask=perm_mask,
1103target_mapping=target_mapping,
1104token_type_ids=token_type_ids,
1105input_mask=input_mask,
1106head_mask=head_mask,
1107inputs_embeds=inputs_embeds,
1108use_cache=use_cache,
1109output_attentions=output_attentions,
1110output_hidden_states=output_hidden_states,
1111)
1112output = transformer_outputs[0]
1113
1114output = self.sequence_summary(output)
1115logits = self.logits_proj(output)
1116
1117outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
1118
1119if labels is not None:
1120loss = self.compute_loss(labels, logits)
1121outputs = (loss,) + outputs
1122
1123return outputs # (loss), logits, (hidden_states), (attentions)
1124
1125
1126@add_start_docstrings(
1127"""XLNET Model with a multiple choice classification head on top (a linear layer on top of
1128the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1129XLNET_START_DOCSTRING,
1130)
1131class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
1132def __init__(self, config, *inputs, **kwargs):
1133super().__init__(config, *inputs, **kwargs)
1134
1135self.transformer = TFXLNetMainLayer(config, name="transformer")
1136self.sequence_summary = TFSequenceSummary(
1137config, initializer_range=config.initializer_range, name="sequence_summary"
1138)
1139self.logits_proj = tf.keras.layers.Dense(
11401, kernel_initializer=get_initializer(config.initializer_range), name="logits_proj"
1141)
1142
1143@property
1144def dummy_inputs(self):
1145""" Dummy inputs to build the network.
1146
1147Returns:
1148tf.Tensor with dummy inputs
1149"""
1150return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
1151
1152@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
1153@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
1154def call(
1155self,
1156inputs=None,
1157token_type_ids=None,
1158input_mask=None,
1159attention_mask=None,
1160mems=None,
1161perm_mask=None,
1162target_mapping=None,
1163head_mask=None,
1164inputs_embeds=None,
1165use_cache=True,
1166output_attentions=None,
1167output_hidden_states=None,
1168labels=None,
1169training=False,
1170):
1171r"""
1172labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1173Labels for computing the multiple choice classification loss.
1174Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
1175of the input tensors. (see `input_ids` above)
1176
1177Return:
1178:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1179classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
1180`num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
1181
1182Classification scores (before SoftMax).
1183hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1184tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
1185of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1186
1187Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1188attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1189tuple of :obj:`tf.Tensor` (one for each layer) of shape
1190:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
1191
1192Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1193heads.
1194"""
1195if isinstance(inputs, (tuple, list)):
1196input_ids = inputs[0]
1197attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
1198mems = inputs[2] if len(inputs) > 2 else mems
1199perm_mask = inputs[3] if len(inputs) > 3 else perm_mask
1200target_mapping = inputs[4] if len(inputs) > 4 else target_mapping
1201token_type_ids = inputs[5] if len(inputs) > 5 else token_type_ids
1202input_mask = inputs[6] if len(inputs) > 6 else input_mask
1203head_mask = inputs[7] if len(inputs) > 7 else head_mask
1204inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
1205use_cache = inputs[9] if len(inputs) > 9 else use_cache
1206output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
1207output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
1208labels = inputs[12] if len(inputs) > 12 else labels
1209assert len(inputs) <= 13, "Too many inputs."
1210elif isinstance(inputs, (dict, BatchEncoding)):
1211input_ids = inputs.get("input_ids")
1212attention_mask = inputs.get("attention_mask", attention_mask)
1213mems = inputs.get("mems", mems)
1214perm_mask = inputs.get("perm_mask", perm_mask)
1215target_mapping = inputs.get("target_mapping", target_mapping)
1216token_type_ids = inputs.get("token_type_ids", token_type_ids)
1217input_mask = inputs.get("input_mask", input_mask)
1218head_mask = inputs.get("head_mask", head_mask)
1219inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
1220use_cache = inputs.get("use_cache", use_cache)
1221output_attentions = inputs.get("output_attentions", output_attentions)
1222output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
1223labels = inputs.get("labels", labels)
1224assert len(inputs) <= 13, "Too many inputs."
1225else:
1226input_ids = inputs
1227
1228if input_ids is not None:
1229num_choices = shape_list(input_ids)[1]
1230seq_length = shape_list(input_ids)[2]
1231else:
1232num_choices = shape_list(inputs_embeds)[1]
1233seq_length = shape_list(inputs_embeds)[2]
1234
1235flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
1236flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
1237flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
1238flat_input_mask = tf.reshape(input_mask, (-1, seq_length)) if input_mask is not None else None
1239flat_inputs_embeds = (
1240tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
1241if inputs_embeds is not None
1242else None
1243)
1244
1245flat_inputs = [
1246flat_input_ids,
1247flat_attention_mask,
1248mems,
1249perm_mask,
1250target_mapping,
1251flat_token_type_ids,
1252flat_input_mask,
1253head_mask,
1254flat_inputs_embeds,
1255use_cache,
1256output_attentions,
1257output_hidden_states,
1258]
1259
1260transformer_outputs = self.transformer(flat_inputs, training=training)
1261output = transformer_outputs[0]
1262logits = self.sequence_summary(output)
1263logits = self.logits_proj(logits)
1264reshaped_logits = tf.reshape(logits, (-1, num_choices))
1265
1266outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
1267
1268if labels is not None:
1269loss = self.compute_loss(labels, reshaped_logits)
1270outputs = (loss,) + outputs
1271
1272return outputs # (loss), logits, (mems), (hidden states), (attentions)
1273
1274
1275@add_start_docstrings(
1276"""XLNet Model with a token classification head on top (a linear layer on top of
1277the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
1278XLNET_START_DOCSTRING,
1279)
1280class TFXLNetForTokenClassification(TFXLNetPreTrainedModel, TFTokenClassificationLoss):
1281def __init__(self, config, *inputs, **kwargs):
1282super().__init__(config, *inputs, **kwargs)
1283self.num_labels = config.num_labels
1284
1285self.transformer = TFXLNetMainLayer(config, name="transformer")
1286self.classifier = tf.keras.layers.Dense(
1287config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
1288)
1289
1290@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
1291@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
1292def call(
1293self,
1294inputs=None,
1295attention_mask=None,
1296mems=None,
1297perm_mask=None,
1298target_mapping=None,
1299token_type_ids=None,
1300input_mask=None,
1301head_mask=None,
1302inputs_embeds=None,
1303use_cache=True,
1304output_attentions=None,
1305output_hidden_states=None,
1306labels=None,
1307training=False,
1308):
1309r"""
1310labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1311Labels for computing the token classification loss.
1312Indices should be in ``[0, ..., config.num_labels - 1]``.
1313
1314Return:
1315:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
1316logits (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:(batch_size, config.num_labels)`):
1317Classification scores (before SoftMax).
1318mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
1319Contains pre-computed hidden-states (key and values in the attention blocks).
1320Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
1321should not be passed as input ids as they have already been computed.
1322hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1323tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
1324of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1325
1326Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1327attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1328tuple of :obj:`tf.Tensor` (one for each layer) of shape
1329:obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
1330
1331Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1332heads.
1333"""
1334if isinstance(inputs, (tuple, list)):
1335labels = inputs[12] if len(inputs) > 12 else labels
1336if len(inputs) > 12:
1337inputs = inputs[:12]
1338elif isinstance(inputs, (dict, BatchEncoding)):
1339labels = inputs.pop("labels", labels)
1340
1341transformer_outputs = self.transformer(
1342inputs,
1343attention_mask=attention_mask,
1344mems=mems,
1345perm_mask=perm_mask,
1346target_mapping=target_mapping,
1347token_type_ids=token_type_ids,
1348input_mask=input_mask,
1349head_mask=head_mask,
1350inputs_embeds=inputs_embeds,
1351use_cache=use_cache,
1352output_attentions=output_attentions,
1353output_hidden_states=output_hidden_states,
1354training=training,
1355)
1356output = transformer_outputs[0]
1357
1358logits = self.classifier(output)
1359
1360outputs = (logits,) + transformer_outputs[1:] # Keep mems, hidden states, attentions if there are in it
1361
1362if labels is not None:
1363loss = self.compute_loss(labels, logits)
1364outputs = (loss,) + outputs
1365
1366return outputs # (loss), logits, (hidden_states), (attentions)
1367
1368
1369@add_start_docstrings(
1370"""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
1371the hidden-states output to compute `span start logits` and `span end logits`). """,
1372XLNET_START_DOCSTRING,
1373)
class TFXLNetForQuestionAnsweringSimple(TFXLNetPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.transformer = TFXLNetMainLayer(config, name="transformer")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="xlnet-base-cased")
    def call(
        self,
        inputs=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.

        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.XLNetConfig`) and inputs:
            loss (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(1,)`, `optional`, returned when :obj:`start_positions` and :obj:`end_positions` are provided):
                Total span extraction loss: the sum of a cross-entropy loss for the start and end positions.
            start_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
                Span-start scores (before SoftMax).
            end_scores (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`):
                Span-end scores (before SoftMax).
            mems (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
                Contains pre-computed hidden-states (key and values in the attention blocks).
                Can be used (see the `mems` input) to speed up sequential decoding. The token ids which have their past
                given to this model should not be passed as input ids as they have already been computed.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
                Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.

                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
                Tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

                Attention weights after the attention softmax, used to compute the weighted average in the
                self-attention heads.
        """
        # `inputs` may also carry `start_positions`/`end_positions` at indices 12 and 13, or as dict entries.
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[12] if len(inputs) > 12 else start_positions
            end_positions = inputs[13] if len(inputs) > 13 else end_positions
            if len(inputs) > 12:
                inputs = inputs[:12]
        elif isinstance(inputs, (dict, BatchEncoding)):
            start_positions = inputs.pop("start_positions", start_positions)
            end_positions = inputs.pop("end_positions", end_positions)

        transformer_outputs = self.transformer(
            inputs,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            training=training,
        )

        sequence_output = transformer_outputs[0]

        # Project each token's hidden state to two scores, then split them into start and end logits,
        # each of shape (batch_size, sequence_length).
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + transformer_outputs[
            1:
        ]  # Keep mems, hidden states and attentions if they are present

        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions, "end_position": end_positions}
            loss = self.compute_loss(labels, outputs[:2])
            outputs = (loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)


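# Usage sketch (not part of the original file): a minimal example for TFXLNetForQuestionAnsweringSimple,
# assuming the "xlnet-base-cased" checkpoint and the transformers tokenizer API; the answer span is read
# off the argmax of the start/end logits.
#
#     from transformers import XLNetTokenizer, TFXLNetForQuestionAnsweringSimple
#     import tensorflow as tf
#
#     tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
#     model = TFXLNetForQuestionAnsweringSimple.from_pretrained("xlnet-base-cased")
#
#     question, context = "Who was Jim Henson?", "Jim Henson was a nice puppet"
#     encoding = tokenizer.encode_plus(question, context, return_tensors="tf")
#
#     start_logits, end_logits = model(encoding)[:2]
#     start = int(tf.argmax(start_logits, axis=-1)[0])
#     end = int(tf.argmax(end_logits, axis=-1)[0])
#     answer_ids = encoding["input_ids"][0, start : end + 1].numpy().tolist()
#     print(tokenizer.decode(answer_ids))

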
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD
#     (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
#     XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
# class TFXLNetForQuestionAnswering(TFXLNetPreTrainedModel):
#     r"""
#     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
#         **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top)``
#             Log probabilities for the top config.start_n_top start token possibilities (beam-search).
#         **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top)``
#             Indices for the top config.start_n_top start token possibilities (beam-search).
#         **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
#             Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
#         **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
#             Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
#         **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
#             ``tf.Tensor`` of shape ``(batch_size,)``
#             Log probabilities for the ``is_impossible`` label of the answers.
#         **mems**:
#             list of ``tf.Tensor`` (one for each layer)
#             that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
#             if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
#             See details in the docstring of the `mems` input above.
#         **hidden_states**: (`optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``)
#             list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
#             of shape ``(batch_size, sequence_length, hidden_size)``:
#             Hidden-states of the model at the output of each layer plus the initial embedding outputs.
#         **attentions**: (`optional`, returned when ``output_attentions=True``)
#             list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
#             Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
#
#     Examples::
#
#         # For example purposes. Not runnable.
#         tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
#         model = TFXLNetForQuestionAnswering.from_pretrained('xlnet-large-cased')
#         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
#         start_positions = tf.constant([1])
#         end_positions = tf.constant([3])
#         outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
#         loss = outputs[0]
#
#     """
#     def __init__(self, config, *inputs, **kwargs):
#         super().__init__(config, *inputs, **kwargs)
#         self.start_n_top = config.start_n_top
#         self.end_n_top = config.end_n_top
#
#         self.transformer = TFXLNetMainLayer(config, name='transformer')
#         self.start_logits = TFPoolerStartLogits(config, name='start_logits')
#         self.end_logits = TFPoolerEndLogits(config, name='end_logits')
#         self.answer_class = TFPoolerAnswerClass(config, name='answer_class')
#
#     def call(self, inputs, training=False):
#         transformer_outputs = self.transformer(inputs, training=training)
#         hidden_states = transformer_outputs[0]
#         start_logits = self.start_logits(hidden_states, p_mask=p_mask)
#
#         outputs = transformer_outputs[1:]  # Keep mems, hidden states and attentions if they are present
#
#         if start_positions is not None and end_positions is not None:
#             # If we are on multi-GPU, let's remove the dimension added by batch splitting
#             for x in (start_positions, end_positions, cls_index, is_impossible):
#                 if x is not None and x.dim() > 1:
#                     x.squeeze_(-1)
#
#             # during training, compute the end logits based on the ground truth of the start position
#             end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
#
#             loss_fct = CrossEntropyLoss()
#             start_loss = loss_fct(start_logits, start_positions)
#             end_loss = loss_fct(end_logits, end_positions)
#             total_loss = (start_loss + end_loss) / 2
#
#             if cls_index is not None and is_impossible is not None:
#                 # Predict answerability from the representation of CLS and START
#                 cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
#                 loss_fct_cls = nn.BCEWithLogitsLoss()
#                 cls_loss = loss_fct_cls(cls_logits, is_impossible)
#
#                 # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
#                 total_loss += cls_loss * 0.5
#
#             outputs = (total_loss,) + outputs
#
#         else:
#             # during inference, compute the end logits based on beam search
#             bsz, slen, hsz = hidden_states.size()
#             start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)
#
#             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
#             start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
#             start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
#             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
#
#             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
#             p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
#             end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
#             end_log_probs = F.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)
#
#             end_top_log_probs, end_top_index = torch.topk(end_log_probs, self.end_n_top, dim=1)  # shape (bsz, end_n_top, start_n_top)
#             end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
#             end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
#
#             start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)  # get the representation of START as weighted sum of hidden states
#             cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)  # Shape (batch_size,): one single `cls_logits` for each sample
#
#             outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
#
#         # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
#         # or (if labels are provided) (total_loss,)
#         return outputs

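# TF sketch (not part of the original file): the commented-out beam-search branch above is still written with
# PyTorch ops. Under the assumption that a TensorFlow port would follow the same shapes, the start-position
# top-k selection and the gathering of the corresponding hidden states could look roughly like this; the
# helper name `select_top_start_states` is hypothetical.
#
#     import tensorflow as tf
#
#     def select_top_start_states(hidden_states, start_logits, start_n_top):
#         """hidden_states: (bsz, slen, hsz); start_logits: (bsz, slen)."""
#         start_probs = tf.nn.softmax(start_logits, axis=-1)                              # (bsz, slen)
#         start_top_probs, start_top_index = tf.math.top_k(start_probs, k=start_n_top)    # (bsz, start_n_top)
#         start_states = tf.gather(hidden_states, start_top_index, batch_dims=1)          # (bsz, start_n_top, hsz)
#         return start_top_probs, start_top_index, start_states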