# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model.
    Adapted from https://github.com/kimiyoung/transformer-xl.
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""


import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_transfo_xl import TransfoXLConfig
from .file_utils import ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
from .modeling_utils import PreTrainedModel


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "TransfoXLConfig"
_TOKENIZER_FOR_DOC = "TransfoXLTokenizer"

TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "transfo-xl-wt103",
    # See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
]


def build_tf_to_pytorch_map(model, config):
    """ A map of modules from TF to PyTorch.
        This time I use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
    """
    tf_to_pt_map = {}

    if hasattr(model, "transformer"):
        # We are loading a TransfoXLLMHeadModel => we will also load the Adaptive Softmax
        tf_to_pt_map.update(
            {
                "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
                "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias,
            }
        )
        for i, (out_l, proj_l, tie_proj) in enumerate(
            zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
        ):
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
            if config.tie_weight:
                tf_to_pt_map.update({layer_str + "b": out_l.bias})
            else:
                raise NotImplementedError
                # I don't think this is implemented in the TF code
                tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias})
            if not tie_proj:
                tf_to_pt_map.update({layer_str + "proj": proj_l})
        # Now load the rest of the transformer
        model = model.transformer

    # Embeddings
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
        tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l})

    # Transformer blocks
    for i, b in enumerate(model.layers):
        layer_str = "transformer/layer_%d/" % i
        tf_to_pt_map.update(
            {
                layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
                layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
                layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
                layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
                layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
                layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
                layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
                layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
                layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
                layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
                layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
            }
        )

    # Relative positioning biases
    if config.untie_r:
        r_r_list = []
        r_w_list = []
        for b in model.layers:
            r_r_list.append(b.dec_attn.r_r_bias)
            r_w_list.append(b.dec_attn.r_w_bias)
    else:
        r_r_list = [model.r_r_bias]
        r_w_list = [model.r_w_bias]
    tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list})
    return tf_to_pt_map


def load_tf_weights_in_transfo_xl(model, config, tf_path):
    """ Load TF checkpoints in a PyTorch model.
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    tf_weights = {}
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        tf_weights[name] = array

    for name, pointer in tf_to_pt_map.items():
        assert name in tf_weights
        array = tf_weights[name]
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if "kernel" in name or "proj" in name:
            array = np.transpose(array)
        if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1:
            # Here we will split the TF weights
            assert len(pointer) == array.shape[0]
            for i, p_i in enumerate(pointer):
                arr_i = array[i, ...]
                try:
                    assert p_i.shape == arr_i.shape
                except AssertionError as e:
                    e.args += (p_i.shape, arr_i.shape)
                    raise
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
                p_i.data = torch.from_numpy(arr_i)
        else:
            try:
                assert pointer.shape == array.shape
            except AssertionError as e:
                e.args += (pointer.shape, array.shape)
                raise
            logger.info("Initialize PyTorch weight {}".format(name))
            pointer.data = torch.from_numpy(array)
        tf_weights.pop(name, None)
        tf_weights.pop(name + "/Adam", None)
        tf_weights.pop(name + "/Adam_1", None)

    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
    return model

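# Usage sketch (illustrative, not part of the original file); the checkpoint
# path below is a placeholder:
#
#     config = TransfoXLConfig()
#     model = TransfoXLLMHeadModel(config)
#     model = load_tf_weights_in_transfo_xl(model, config, "/path/to/tf_checkpoint")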

class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:, None, :].expand(-1, bsz, -1)
        else:
            return pos_emb[:, None, :]

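# Shape sketch (illustrative, not part of the original file): for a descending
# position sequence of length klen, the module returns a (klen, 1, demb) tensor
# (or (klen, bsz, demb) when bsz is given), with sin in the first demb/2
# channels and cos in the rest:
#
#     pos_emb_layer = PositionalEmbedding(demb=512)
#     pos_seq = torch.arange(15, -1, -1.0)   # klen = 16, descending as in the model
#     pos_emb = pos_emb_layer(pos_seq)       # shape: (16, 1, 512)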

class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
        super().__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = self.CoreNet(inp)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output


class RelPartialLearnableMultiHeadAttn(nn.Module):
    def __init__(
        self,
        n_head,
        d_model,
        d_head,
        dropout,
        dropatt=0,
        tgt_len=None,
        ext_len=None,
        mem_len=None,
        pre_lnorm=False,
        r_r_bias=None,
        r_w_bias=None,
        layer_norm_epsilon=1e-5,
    ):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        else:
            self.r_r_bias = r_r_bias
            self.r_w_bias = r_w_bias

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def _rel_shift(self, x):
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
        x_padded = x_padded.view(*x_padded_shape)

        x = x_padded[1:].view_as(x)

        return x

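    # How _rel_shift realigns scores (illustrative note, not part of the
    # original file): prepending a zero column, viewing the tensor with the
    # first two dims swapped, and dropping the first row moves entry (i, j)
    # to (i, j - i + qlen - 1). E.g. for
    #
    #     x = torch.arange(9.0).view(3, 3, 1, 1)   # rows [[0,1,2],[3,4,5],[6,7,8]]
    #
    # _rel_shift(x) yields rows [[2,0,3],[4,5,0],[6,7,8]]; entries above the
    # main diagonal are garbage and are removed by the causal attention mask.
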
    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attentions=False):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head

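        # Note (added for clarity, not in the original file): AC and BD below
        # implement the Transformer-XL score decomposition: AC gathers the
        # content-based terms (a) + (c) via r_w_bias (u in the paper), and BD
        # gathers the position-based terms (b) + (d) via r_r_bias (v); BD is
        # then realigned with _rel_shift.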
        # compute attention score
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
        AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + self.r_r_bias
        BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k))  # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        # compute attention probability
        if attn_mask is not None and torch.sum(attn_mask).item():
            attn_mask = attn_mask == 1  # Switch to bool
            if attn_mask.dim() == 2:
                if next(self.parameters()).dtype == torch.float16:
                    attn_score = (
                        attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score)
                    )
                else:
                    attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
            elif attn_mask.dim() == 3:
                if next(self.parameters()).dtype == torch.float16:
                    attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
                else:
                    attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # Mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        # compute attention vector
        attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            outputs = [w + attn_out]
        else:
            # residual connection + layer normalization
            outputs = [self.layer_norm(w + attn_out)]

        if output_attentions:
            outputs.append(attn_prob)

        return outputs


class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5, **kwargs):
        super().__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(
            n_head, d_model, d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs
        )
        self.pos_ff = PositionwiseFF(
            d_model, d_inner, dropout, pre_lnorm=kwargs.get("pre_lnorm"), layer_norm_epsilon=layer_norm_epsilon
        )

    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None, output_attentions=False):

        attn_outputs = self.dec_attn(
            dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask, output_attentions=output_attentions,
        )
        ff_output = self.pos_ff(attn_outputs[0])

        outputs = [ff_output] + attn_outputs[1:]

        return outputs


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed

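# Sizing sketch (illustrative, not part of the original file): with
# cutoffs=[20000, 40000, 200000], d_embed=1024 and div_val=4 -- the defaults
# used by the transfo-xl-wt103 checkpoint -- the four vocabulary clusters get
# embedding widths 1024, 256, 64 and 16, each projected back up to d_proj:
#
#     emb = AdaptiveEmbedding(n_token=267735, d_embed=1024, d_proj=1024,
#                             cutoffs=[20000, 40000, 200000], div_val=4)
#     tokens = torch.randint(0, 267735, (8, 32))
#     out = emb(tokens)   # shape: (8, 32, 1024)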

class TransfoXLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = TransfoXLConfig
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

    def _init_weight(self, weight):
        if self.config.init == "uniform":
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
        elif self.config.init == "normal":
            nn.init.normal_(weight, 0.0, self.config.init_std)

    def _init_bias(self, bias):
        nn.init.constant_(bias, 0.0)

    def _init_weights(self, m):
        """ Initialize the weights.
        """
        classname = m.__class__.__name__
        if classname.find("Linear") != -1:
            if hasattr(m, "weight") and m.weight is not None:
                self._init_weight(m.weight)
            if hasattr(m, "bias") and m.bias is not None:
                self._init_bias(m.bias)
        elif classname.find("AdaptiveEmbedding") != -1:
            if hasattr(m, "emb_projs"):
                for i in range(len(m.emb_projs)):
                    if m.emb_projs[i] is not None:
                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find("Embedding") != -1:
            if hasattr(m, "weight"):
                self._init_weight(m.weight)
        elif classname.find("ProjectedAdaptiveLogSoftmax") != -1:
            if hasattr(m, "cluster_weight") and m.cluster_weight is not None:
                self._init_weight(m.cluster_weight)
            if hasattr(m, "cluster_bias") and m.cluster_bias is not None:
                self._init_bias(m.cluster_bias)
            if hasattr(m, "out_projs"):
                for i in range(len(m.out_projs)):
                    if m.out_projs[i] is not None:
                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
        elif classname.find("LayerNorm") != -1:
            if hasattr(m, "weight"):
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
            if hasattr(m, "bias") and m.bias is not None:
                self._init_bias(m.bias)
        else:
            if hasattr(m, "r_emb"):
                self._init_weight(m.r_emb)
            if hasattr(m, "r_w_bias"):
                self._init_weight(m.r_w_bias)
            if hasattr(m, "r_r_bias"):
                self._init_weight(m.r_r_bias)
            if hasattr(m, "r_bias"):
                self._init_bias(m.r_bias)

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, layer: Optional[int] = -1):
        """ Resize the input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end; reducing the size will remove vectors from the end.
                If not provided or None, does nothing and just returns a pointer to the input tokens
                ``torch.nn.Embeddings`` Module of the model.
            layer: (`optional`) int:
                Layer of the `AdaptiveEmbedding` where the resizing should be done. By default the last layer
                is resized. Be aware that when resizing layers other than the last one, you have to ensure that
                the new token(s) in the tokenizer are at the corresponding position.

        Return: ``torch.nn.Embeddings``
            Pointer to the input tokens Embeddings Module of the model
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed

        if new_num_tokens is None:
            return self.get_input_embeddings()

        new_num_tokens_layer, layer = self._get_new_num_tokens_layer(new_num_tokens, layer)
        assert new_num_tokens_layer > 0, "The size of the new embedding layer cannot be 0 or less"
        model_embeds = base_model._resize_token_embeddings(new_num_tokens_layer, layer)

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens
        base_model.n_token = new_num_tokens

        new_embedding_shapes = self._get_embedding_shapes()
        self._resize_cutoffs(new_num_tokens, new_num_tokens_layer, new_embedding_shapes, layer)

        # Tie weights again if needed
        self.tie_weights()

        return model_embeds

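    # Usage sketch (illustrative, not part of the original file): growing the
    # vocabulary by adding slots to the last adaptive-embedding layer; the new
    # total size below is a hypothetical example value.
    #
    #     model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
    #     model.resize_token_embeddings(new_num_tokens=267745, layer=-1)
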
    def _get_new_num_tokens_layer(self, new_num_tokens, layer):
        embeddings = self.get_input_embeddings()
        if layer == -1:
            layer = len(embeddings.emb_layers) - 1
        assert 0 <= layer <= len(embeddings.emb_layers) - 1

        new_num_tokens_layer = (
            new_num_tokens
            - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[:layer]])
            - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[layer + 1 :]])
        )
        return new_num_tokens_layer, layer

    def _get_embedding_shapes(self):
        embeddings = self.get_input_embeddings()
        return [emb.weight.shape[0] for emb in embeddings.emb_layers]

    def _resize_token_embeddings(self, new_num_tokens, layer=-1):
        embeddings = self.get_input_embeddings()
        if new_num_tokens is None:
            return embeddings
        new_embeddings_layer = self._get_resized_embeddings(embeddings.emb_layers[layer], new_num_tokens)
        embeddings.emb_layers[layer] = new_embeddings_layer

        self.set_input_embeddings(embeddings)

        return self.get_input_embeddings()

    def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer):
        embeddings = self.get_input_embeddings()

        for i in range(layer, len(embeddings.cutoffs)):
            embeddings.cutoffs[i] = sum(new_embedding_shapes[: i + 1])

        embeddings.cutoff_ends = [0] + embeddings.cutoffs
        embeddings.n_token = new_num_tokens

        self.config.cutoffs = embeddings.cutoffs[:-1]

        return embeddings.cutoffs


@dataclass
class TransfoXLModelOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past
            given to this model should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    last_hidden_state: torch.FloatTensor
    mems: List[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class TransfoXLLMHeadModelOutput(ModelOutput):
    """
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).

    Args:
        losses (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length-1)`, `optional`, returned when ``labels`` is provided):
            Language modeling losses (not reduced).
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax).
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past
            given to this model should not be passed as input ids as they have already been computed.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the
            self-attention heads.
    """

    losses: Optional[torch.FloatTensor] = None
    prediction_scores: torch.FloatTensor = None
    mems: List[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


TRANSFO_XL_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

TRANSFO_XL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their
            mems given to this model should not be passed as input ids as they have already been computed.
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""


@add_start_docstrings(
    "The bare Transformer-XL Model transformer outputting raw hidden-states without any specific head on top.",
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLModel(TransfoXLPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.n_token = config.vocab_size

        self.d_embed = config.d_embed
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_head = config.d_head

        self.word_emb = AdaptiveEmbedding(
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
        )

        self.drop = nn.Dropout(config.dropout)

        self.n_layer = config.n_layer

        self.tgt_len = config.tgt_len
        self.mem_len = config.mem_len
        self.ext_len = config.ext_len
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

        self.attn_type = config.attn_type

        if not config.untie_r:
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))

        self.layers = nn.ModuleList()
        if config.attn_type == 0:  # the default attention
            for i in range(config.n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        config.n_head,
                        config.d_model,
                        config.d_head,
                        config.d_inner,
                        config.dropout,
                        tgt_len=config.tgt_len,
                        ext_len=config.ext_len,
                        mem_len=config.mem_len,
                        dropatt=config.dropatt,
                        pre_lnorm=config.pre_lnorm,
                        r_w_bias=None if config.untie_r else self.r_w_bias,
                        r_r_bias=None if config.untie_r else self.r_r_bias,
                        layer_norm_epsilon=config.layer_norm_epsilon,
                    )
                )
        else:  # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
            raise NotImplementedError  # Removed them to avoid maintaining dead code

        self.same_length = config.same_length
        self.clamp_len = config.clamp_len

        if self.attn_type == 0:  # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        self.init_weights()

    def get_input_embeddings(self):
        return self.word_emb

    def set_input_embeddings(self, new_embeddings):
        self.word_emb = new_embeddings

    def backward_compatible(self):
        self.sample_softmax = -1

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def _prune_heads(self, heads):
        logger.info("Head pruning is not implemented for Transformer-XL model")
        pass

    def init_mems(self, bsz):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer):
                empty = torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, mlen, qlen):
        # does not deal with None
        if mems is None:
            return None

        # mems is not None
        assert len(hids) == len(mems), "len(hids) != len(mems)"

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

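    # Window arithmetic sketch (illustrative, not part of the original file):
    # with mem_len=160, ext_len=0, a memory of mlen=160 and a new segment of
    # qlen=32, end_idx = 160 + 32 = 192 and beg_idx = 192 - 160 = 32, so the
    # updated memory keeps the most recent 160 of the 192 cached hidden states.
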
    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="transfo-xl-wt103",
        output_type=TransfoXLModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        mems=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = input_ids.transpose(0, 1).contiguous()
            qlen, bsz = input_ids.size()
        elif inputs_embeds is not None:
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if mems is None:
            mems = self.init_mems(bsz)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
            head_mask = head_mask.to(
                dtype=next(self.parameters()).dtype
            )  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.n_layer

        if inputs_embeds is not None:
            word_emb = inputs_embeds
        else:
            word_emb = self.word_emb(input_ids)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
        else:
            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
                :, :, None
            ]

        hids = []
        attentions = [] if output_attentions else None
        if self.attn_type == 0:  # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            for i, layer in enumerate(self.layers):
                hids.append(core_out)
                mems_i = None if mems is None else mems[i]
                layer_outputs = layer(
                    core_out,
                    pos_emb,
                    dec_attn_mask=dec_attn_mask,
                    mems=mems_i,
                    head_mask=head_mask[i],
                    output_attentions=output_attentions,
                )
                core_out = layer_outputs[0]
                if output_attentions:
                    attentions.append(layer_outputs[1])
        else:  # learnable embeddings and absolute embeddings
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, mlen, qlen)

        if output_hidden_states:
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
            hids.append(core_out)
            hids = tuple(t.transpose(0, 1).contiguous() for t in hids)
        else:
            hids = None
        if output_attentions:
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
        # We transpose back here to shape [bsz, len, hidden_dim]
        core_out = core_out.transpose(0, 1).contiguous()

        if not return_dict:
            return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)

        return TransfoXLModelOutput(
            last_hidden_state=core_out, mems=new_mems, hidden_states=hids, attentions=attentions,
        )

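# Usage sketch for the base model (illustrative, not part of the original file):
#
#     from transformers import TransfoXLTokenizer
#     tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
#     model = TransfoXLModel.from_pretrained("transfo-xl-wt103")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs, return_dict=True)
#     last_hidden_state = outputs.last_hidden_state   # (batch, seq_len, d_model)
#     mems = outputs.mems                             # feed back in the next call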

@add_start_docstrings(
    """The Transformer-XL Model with a language modeling head on top
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
    TRANSFO_XL_START_DOCSTRING,
)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = TransfoXLModel(config)
        self.sample_softmax = config.sample_softmax

        assert (
            self.sample_softmax <= 0
        ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"

        self.crit = ProjectedAdaptiveLogSoftmax(
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
        )

        self.init_weights()

    def tie_weights(self):
        """
        Run this to be sure output and input (adaptive) softmax weights are tied
        """

        if self.config.tie_weight:
            for i in range(len(self.crit.out_layers)):
                self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
        if self.config.tie_projs:
            for i, tie_proj in enumerate(self.config.tie_projs):
                if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
                    if self.config.torchscript:
                        self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
                    else:
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
                elif tie_proj and self.config.div_val != 1:
                    if self.config.torchscript:
                        self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
                    else:
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.transformer.reset_length(tgt_len, ext_len, mem_len)

    def init_mems(self, bsz):
        return self.transformer.init_mems(bsz)

    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="transfo-xl-wt103",
        output_type=TransfoXLLMHeadModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        mems=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``.
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``.
            All labels set to ``-100`` are ignored (masked); the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if input_ids is not None:
            bsz, tgt_len = input_ids.size(0), input_ids.size(1)
        elif inputs_embeds is not None:
            bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        transformer_outputs = self.transformer(
            input_ids,
            mems=mems,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden = transformer_outputs[0]
        pred_hid = last_hidden[:, -tgt_len:]

        softmax_output = self.crit(pred_hid, labels)
        prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
        loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None

        if not return_dict:
            output = (prediction_scores,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TransfoXLLMHeadModelOutput(
            losses=loss,
            prediction_scores=prediction_scores,
            mems=transformer_outputs.mems,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

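    # Usage sketch with labels (illustrative, not part of the original file):
    # passing labels = input_ids yields per-token (unreduced) LM losses of
    # shape (batch_size, sequence_length - 1), since labels are shifted inside.
    #
    #     model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
    #     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    #     outputs = model(**inputs, labels=inputs["input_ids"], return_dict=True)
    #     loss = outputs.losses.mean()    # reduce to a scalar for backprop
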
    def get_output_embeddings(self):
        """ Double-check if you are using adaptive softmax.
        """
        if self.sample_softmax > 0:
            return self.out_layer
        else:
            return self.crit.out_layers[-1]

    def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
        inputs = {}

        # if past is defined in model kwargs then use it for faster decoding
        if past:
            inputs["mems"] = past
            inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1)
        else:
            inputs["input_ids"] = input_ids

        return inputs

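    # Generation sketch (illustrative, not part of the original file): when
    # `past` (the mems) is available, only the last token needs to be fed back;
    # the standard generate() API drives this method, e.g.:
    #
    #     generated = model.generate(inputs["input_ids"], max_length=50)
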
    def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer):
        new_cutoffs = super()._resize_cutoffs(new_num_tokens, new_emb_size, new_embedding_shapes, layer)

        self.crit.cutoffs = new_cutoffs
        self.crit.cutoff_ends = [0] + new_cutoffs
        self.crit.n_token = new_num_tokens