CSS-LM (fork)

modeling_transfo_xl.py · 1089 lines · 45.2 KB
1
# coding=utf-8
2
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
""" PyTorch Transformer XL model.
17
    Adapted from https://github.com/kimiyoung/transformer-xl.
18
    In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
19
"""
20

21

22
import logging
23
from dataclasses import dataclass
24
from typing import List, Optional, Tuple
25

26
import torch
27
import torch.nn as nn
28
import torch.nn.functional as F
29

30
from .configuration_transfo_xl import TransfoXLConfig
31
from .file_utils import ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
32
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax
33
from .modeling_utils import PreTrainedModel
34

35

36
logger = logging.getLogger(__name__)
37

38
_CONFIG_FOR_DOC = "TransfoXLConfig"
39
_TOKENIZER_FOR_DOC = "TransfoXLTokenizer"
40

41
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
42
    "transfo-xl-wt103",
43
    # See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
44
]
45

46

47
def build_tf_to_pytorch_map(model, config):
48
    """ A map of modules from TF to PyTorch.
49
        We use a map to keep the PyTorch model as identical to the original PyTorch model as possible.
50
    """
51
    tf_to_pt_map = {}
52

53
    if hasattr(model, "transformer"):
54
        # We are loading a TransfoXLLMHeadModel => we will also load the Adaptive Softmax
55
        tf_to_pt_map.update(
56
            {
57
                "transformer/adaptive_softmax/cutoff_0/cluster_W": model.crit.cluster_weight,
58
                "transformer/adaptive_softmax/cutoff_0/cluster_b": model.crit.cluster_bias,
59
            }
60
        )
61
        for i, (out_l, proj_l, tie_proj) in enumerate(
62
            zip(model.crit.out_layers, model.crit.out_projs, config.tie_projs)
63
        ):
64
            layer_str = "transformer/adaptive_softmax/cutoff_%d/" % i
65
            if config.tie_weight:
66
                tf_to_pt_map.update({layer_str + "b": out_l.bias})
67
            else:
68
                raise NotImplementedError
69
                # I don't think this is implemented in the TF code
70
                tf_to_pt_map.update({layer_str + "lookup_table": out_l.weight, layer_str + "b": out_l.bias})
71
            if not tie_proj:
72
                tf_to_pt_map.update({layer_str + "proj": proj_l})
73
        # Now load the rest of the transformer
74
        model = model.transformer
75

76
    # Embeddings
77
    for i, (embed_l, proj_l) in enumerate(zip(model.word_emb.emb_layers, model.word_emb.emb_projs)):
78
        layer_str = "transformer/adaptive_embed/cutoff_%d/" % i
79
        tf_to_pt_map.update({layer_str + "lookup_table": embed_l.weight, layer_str + "proj_W": proj_l})
80

81
    # Transformer blocks
82
    for i, b in enumerate(model.layers):
83
        layer_str = "transformer/layer_%d/" % i
84
        tf_to_pt_map.update(
85
            {
86
                layer_str + "rel_attn/LayerNorm/gamma": b.dec_attn.layer_norm.weight,
87
                layer_str + "rel_attn/LayerNorm/beta": b.dec_attn.layer_norm.bias,
88
                layer_str + "rel_attn/o/kernel": b.dec_attn.o_net.weight,
89
                layer_str + "rel_attn/qkv/kernel": b.dec_attn.qkv_net.weight,
90
                layer_str + "rel_attn/r/kernel": b.dec_attn.r_net.weight,
91
                layer_str + "ff/LayerNorm/gamma": b.pos_ff.layer_norm.weight,
92
                layer_str + "ff/LayerNorm/beta": b.pos_ff.layer_norm.bias,
93
                layer_str + "ff/layer_1/kernel": b.pos_ff.CoreNet[0].weight,
94
                layer_str + "ff/layer_1/bias": b.pos_ff.CoreNet[0].bias,
95
                layer_str + "ff/layer_2/kernel": b.pos_ff.CoreNet[3].weight,
96
                layer_str + "ff/layer_2/bias": b.pos_ff.CoreNet[3].bias,
97
            }
98
        )
99

100
    # Relative positioning biases
101
    if config.untie_r:
102
        r_r_list = []
103
        r_w_list = []
104
        for b in model.layers:
105
            r_r_list.append(b.dec_attn.r_r_bias)
106
            r_w_list.append(b.dec_attn.r_w_bias)
107
    else:
108
        r_r_list = [model.r_r_bias]
109
        r_w_list = [model.r_w_bias]
110
    tf_to_pt_map.update({"transformer/r_r_bias": r_r_list, "transformer/r_w_bias": r_w_list})
111
    return tf_to_pt_map
112

113

114
def load_tf_weights_in_transfo_xl(model, config, tf_path):
115
    """ Load tf checkpoints in a pytorch model
116
    """
117
    try:
118
        import numpy as np
119
        import tensorflow as tf
120
    except ImportError:
121
        logger.error(
122
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
123
            "https://www.tensorflow.org/install/ for installation instructions."
124
        )
125
        raise
126
    # Build TF to PyTorch weights loading map
127
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)
128

129
    # Load weights from TF model
130
    init_vars = tf.train.list_variables(tf_path)
131
    tf_weights = {}
132
    for name, shape in init_vars:
133
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
134
        array = tf.train.load_variable(tf_path, name)
135
        tf_weights[name] = array
136

137
    for name, pointer in tf_to_pt_map.items():
138
        assert name in tf_weights
139
        array = tf_weights[name]
140
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
141
        # which are not required for using the pretrained model
142
        if "kernel" in name or "proj" in name:
143
            array = np.transpose(array)
144
        if ("r_r_bias" in name or "r_w_bias" in name) and len(pointer) > 1:
145
            # Here we will split the TF weights
146
            assert len(pointer) == array.shape[0]
147
            for i, p_i in enumerate(pointer):
148
                arr_i = array[i, ...]
149
                try:
150
                    assert p_i.shape == arr_i.shape
151
                except AssertionError as e:
152
                    e.args += (p_i.shape, arr_i.shape)
153
                    raise
154
                logger.info("Initialize PyTorch weight {} for layer {}".format(name, i))
155
                p_i.data = torch.from_numpy(arr_i)
156
        else:
157
            try:
158
                assert pointer.shape == array.shape
159
            except AssertionError as e:
160
                e.args += (pointer.shape, array.shape)
161
                raise
162
            logger.info("Initialize PyTorch weight {}".format(name))
163
            pointer.data = torch.from_numpy(array)
164
        tf_weights.pop(name, None)
165
        tf_weights.pop(name + "/Adam", None)
166
        tf_weights.pop(name + "/Adam_1", None)
167

168
    logger.info("Weights not copied to PyTorch model: {}".format(", ".join(tf_weights.keys())))
169
    return model
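
# Illustrative usage sketch (not part of the original file): converting an original TF
# Transformer-XL checkpoint into this PyTorch model. The checkpoint path below is a
# placeholder, and TensorFlow must be installed for the conversion to run.
#
#     config = TransfoXLConfig()
#     model = TransfoXLLMHeadModel(config)
#     model = load_tf_weights_in_transfo_xl(model, config, "/path/to/tf_checkpoint")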
170

171

172
class PositionalEmbedding(nn.Module):
173
    def __init__(self, demb):
174
        super().__init__()
175

176
        self.demb = demb
177

178
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
179
        self.register_buffer("inv_freq", inv_freq)
180

181
    def forward(self, pos_seq, bsz=None):
182
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
183
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
184

185
        if bsz is not None:
186
            return pos_emb[:, None, :].expand(-1, bsz, -1)
187
        else:
188
            return pos_emb[:, None, :]
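
# Illustrative sketch (not part of the original file): shape behaviour of
# PositionalEmbedding for a descending relative-position sequence, as built later in
# TransfoXLModel.forward.
#
#     pos_emb_layer = PositionalEmbedding(demb=512)
#     pos_seq = torch.arange(15.0, -1, -1)       # klen = 16, most distant position first
#     pos_emb = pos_emb_layer(pos_seq, bsz=2)    # -> torch.Size([16, 2, 512])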
189

190

191
class PositionwiseFF(nn.Module):
192
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False, layer_norm_epsilon=1e-5):
193
        super().__init__()
194

195
        self.d_model = d_model
196
        self.d_inner = d_inner
197
        self.dropout = dropout
198

199
        self.CoreNet = nn.Sequential(
200
            nn.Linear(d_model, d_inner),
201
            nn.ReLU(inplace=True),
202
            nn.Dropout(dropout),
203
            nn.Linear(d_inner, d_model),
204
            nn.Dropout(dropout),
205
        )
206

207
        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
208

209
        self.pre_lnorm = pre_lnorm
210

211
    def forward(self, inp):
212
        if self.pre_lnorm:
213
            # layer normalization + positionwise feed-forward
214
            core_out = self.CoreNet(self.layer_norm(inp))
215

216
            # residual connection
217
            output = core_out + inp
218
        else:
219
            # positionwise feed-forward
220
            core_out = self.CoreNet(inp)
221

222
            # residual connection + layer normalization
223
            output = self.layer_norm(inp + core_out)
224

225
        return output
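
# Illustrative sketch (not part of the original file): PositionwiseFF preserves the
# [qlen, bsz, d_model] shape, so it can be applied directly to the attention output.
#
#     ff = PositionwiseFF(d_model=512, d_inner=2048, dropout=0.1)
#     x = torch.rand(16, 2, 512)     # [qlen, bsz, d_model]
#     y = ff(x)                      # same shape: torch.Size([16, 2, 512])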
226

227

228
class RelPartialLearnableMultiHeadAttn(nn.Module):
229
    def __init__(
230
        self,
231
        n_head,
232
        d_model,
233
        d_head,
234
        dropout,
235
        dropatt=0,
236
        tgt_len=None,
237
        ext_len=None,
238
        mem_len=None,
239
        pre_lnorm=False,
240
        r_r_bias=None,
241
        r_w_bias=None,
242
        layer_norm_epsilon=1e-5,
243
    ):
244
        super().__init__()
245

246
        self.n_head = n_head
247
        self.d_model = d_model
248
        self.d_head = d_head
249
        self.dropout = dropout
250

251
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)
252

253
        self.drop = nn.Dropout(dropout)
254
        self.dropatt = nn.Dropout(dropatt)
255
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
256

257
        self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon)
258

259
        self.scale = 1 / (d_head ** 0.5)
260

261
        self.pre_lnorm = pre_lnorm
262

263
        if r_r_bias is None or r_w_bias is None:  # Biases are not shared
264
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
265
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
266
        else:
267
            self.r_r_bias = r_r_bias
268
            self.r_w_bias = r_w_bias
269

270
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
271

272
    def _rel_shift(self, x):
273
        zero_pad_shape = (x.size(0), 1) + x.size()[2:]
274
        zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
275
        x_padded = torch.cat([zero_pad, x], dim=1)
276

277
        x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
278
        x_padded = x_padded.view(*x_padded_shape)
279

280
        x = x_padded[1:].view_as(x)
281

282
        return x
283

284
    def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attentions=False):
285
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
286

287
        if mems is not None:
288
            cat = torch.cat([mems, w], 0)
289
            if self.pre_lnorm:
290
                w_heads = self.qkv_net(self.layer_norm(cat))
291
            else:
292
                w_heads = self.qkv_net(cat)
293
            r_head_k = self.r_net(r)
294

295
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
296
            w_head_q = w_head_q[-qlen:]
297
        else:
298
            if self.pre_lnorm:
299
                w_heads = self.qkv_net(self.layer_norm(w))
300
            else:
301
                w_heads = self.qkv_net(w)
302
            r_head_k = self.r_net(r)
303

304
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
305

306
        klen = w_head_k.size(0)
307

308
        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
309
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
310
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
311

312
        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)  # rlen x n_head x d_head
313

314
        # compute attention score
315
        rw_head_q = w_head_q + self.r_w_bias  # qlen x bsz x n_head x d_head
316
        AC = torch.einsum("ibnd,jbnd->ijbn", (rw_head_q, w_head_k))  # qlen x klen x bsz x n_head
317

318
        rr_head_q = w_head_q + self.r_r_bias
319
        BD = torch.einsum("ibnd,jnd->ijbn", (rr_head_q, r_head_k))  # qlen x klen x bsz x n_head
320
        BD = self._rel_shift(BD)
321

322
        # [qlen x klen x bsz x n_head]
323
        attn_score = AC + BD
324
        attn_score.mul_(self.scale)
325

326
        # compute attention probability
327
        if attn_mask is not None and torch.sum(attn_mask).item():
328
            attn_mask = attn_mask == 1  # Switch to bool
329
            if attn_mask.dim() == 2:
330
                if next(self.parameters()).dtype == torch.float16:
331
                    attn_score = (
332
                        attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score)
333
                    )
334
                else:
335
                    attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score)
336
            elif attn_mask.dim() == 3:
337
                if next(self.parameters()).dtype == torch.float16:
338
                    attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score)
339
                else:
340
                    attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score)
341

342
        # [qlen x klen x bsz x n_head]
343
        attn_prob = F.softmax(attn_score, dim=1)
344
        attn_prob = self.dropatt(attn_prob)
345

346
        # Mask heads if we want to
347
        if head_mask is not None:
348
            attn_prob = attn_prob * head_mask
349

350
        # compute attention vector
351
        attn_vec = torch.einsum("ijbn,jbnd->ibnd", (attn_prob, w_head_v))
352

353
        # [qlen x bsz x n_head x d_head]
354
        attn_vec = attn_vec.contiguous().view(attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)
355

356
        # linear projection
357
        attn_out = self.o_net(attn_vec)
358
        attn_out = self.drop(attn_out)
359

360
        if self.pre_lnorm:
361
            # residual connection
362
            outputs = [w + attn_out]
363
        else:
364
            # residual connection + layer normalization
365
            outputs = [self.layer_norm(w + attn_out)]
366

367
        if output_attentions:
368
            outputs.append(attn_prob)
369

370
        return outputs
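
# Illustrative sketch (not part of the original file): the attention module consumes the
# token stream `w` and the relative position embeddings `r`. The r_r_bias / r_w_bias
# parameters created here stay uninitialized until init_weights() runs, so this only
# demonstrates the expected shapes.
#
#     attn = RelPartialLearnableMultiHeadAttn(n_head=8, d_model=512, d_head=64, dropout=0.1)
#     w = torch.rand(16, 2, 512)                                  # [qlen, bsz, d_model]
#     r = PositionalEmbedding(512)(torch.arange(15.0, -1, -1))    # [klen, 1, d_model]
#     out = attn(w, r)                                            # out[0]: [16, 2, 512]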
371

372

373
class RelPartialLearnableDecoderLayer(nn.Module):
374
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, layer_norm_epsilon=1e-5, **kwargs):
375
        super().__init__()
376

377
        self.dec_attn = RelPartialLearnableMultiHeadAttn(
378
            n_head, d_model, d_head, dropout, layer_norm_epsilon=layer_norm_epsilon, **kwargs
379
        )
380
        self.pos_ff = PositionwiseFF(
381
            d_model, d_inner, dropout, pre_lnorm=kwargs.get("pre_lnorm"), layer_norm_epsilon=layer_norm_epsilon
382
        )
383

384
    def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None, output_attentions=False):
385

386
        attn_outputs = self.dec_attn(
387
            dec_inp, r, attn_mask=dec_attn_mask, mems=mems, head_mask=head_mask, output_attentions=output_attentions,
388
        )
389
        ff_output = self.pos_ff(attn_outputs[0])
390

391
        outputs = [ff_output] + attn_outputs[1:]
392

393
        return outputs
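
# Illustrative sketch (not part of the original file): one decoder layer is the
# relative-position attention followed by the position-wise feed-forward block, leaving
# the hidden-state shape unchanged.
#
#     layer = RelPartialLearnableDecoderLayer(n_head=8, d_model=512, d_head=64,
#                                             d_inner=2048, dropout=0.1)
#     h = torch.rand(16, 2, 512)
#     r = PositionalEmbedding(512)(torch.arange(15.0, -1, -1))
#     h_next = layer(h, r)[0]        # torch.Size([16, 2, 512])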
394

395

396
class AdaptiveEmbedding(nn.Module):
397
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
398
        super().__init__()
399

400
        self.n_token = n_token
401
        self.d_embed = d_embed
402

403
        self.cutoffs = cutoffs + [n_token]
404
        self.div_val = div_val
405
        self.d_proj = d_proj
406

407
        self.emb_scale = d_proj ** 0.5
408

409
        self.cutoff_ends = [0] + self.cutoffs
410

411
        self.emb_layers = nn.ModuleList()
412
        self.emb_projs = nn.ParameterList()
413
        if div_val == 1:
414
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
415
            if d_proj != d_embed:
416
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
417
        else:
418
            for i in range(len(self.cutoffs)):
419
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
420
                d_emb_i = d_embed // (div_val ** i)
421
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
422
                self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))
423

424
    def forward(self, inp):
425
        if self.div_val == 1:
426
            embed = self.emb_layers[0](inp)
427
            if self.d_proj != self.d_embed:
428
                embed = F.linear(embed, self.emb_projs[0])
429
        else:
430
            param = next(self.parameters())
431
            inp_flat = inp.view(-1)
432
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
433
            for i in range(len(self.cutoffs)):
434
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
435

436
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
437
                indices_i = mask_i.nonzero().squeeze()
438

439
                if indices_i.numel() == 0:
440
                    continue
441

442
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
443
                emb_i = self.emb_layers[i](inp_i)
444
                emb_i = F.linear(emb_i, self.emb_projs[i])
445

446
                emb_flat.index_copy_(0, indices_i, emb_i)
447

448
            embed_shape = inp.size() + (self.d_proj,)
449
            embed = emb_flat.view(embed_shape)
450

451
        embed.mul_(self.emb_scale)
452

453
        return embed
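
# Illustrative sketch (not part of the original file): with div_val > 1 each vocabulary
# bucket gets a smaller embedding size and its own projection back to d_proj. The numbers
# below roughly follow the transfo-xl-wt103 configuration; the projections stay
# uninitialized until init_weights() runs, so this only demonstrates the expected shapes.
#
#     emb = AdaptiveEmbedding(n_token=267735, d_embed=1024, d_proj=1024,
#                             cutoffs=[20000, 40000, 200000], div_val=4)
#     ids = torch.randint(0, 267735, (16, 2))    # [qlen, bsz]
#     vecs = emb(ids)                            # torch.Size([16, 2, 1024])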
454

455

456
class TransfoXLPreTrainedModel(PreTrainedModel):
457
    """ An abstract class to handle weights initialization and
458
        a simple interface for downloading and loading pretrained models.
459
    """
460

461
    config_class = TransfoXLConfig
462
    load_tf_weights = load_tf_weights_in_transfo_xl
463
    base_model_prefix = "transformer"
464

465
    def _init_weight(self, weight):
466
        if self.config.init == "uniform":
467
            nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
468
        elif self.config.init == "normal":
469
            nn.init.normal_(weight, 0.0, self.config.init_std)
470

471
    def _init_bias(self, bias):
472
        nn.init.constant_(bias, 0.0)
473

474
    def _init_weights(self, m):
475
        """ Initialize the weights.
476
        """
477
        classname = m.__class__.__name__
478
        if classname.find("Linear") != -1:
479
            if hasattr(m, "weight") and m.weight is not None:
480
                self._init_weight(m.weight)
481
            if hasattr(m, "bias") and m.bias is not None:
482
                self._init_bias(m.bias)
483
        elif classname.find("AdaptiveEmbedding") != -1:
484
            if hasattr(m, "emb_projs"):
485
                for i in range(len(m.emb_projs)):
486
                    if m.emb_projs[i] is not None:
487
                        nn.init.normal_(m.emb_projs[i], 0.0, self.config.proj_init_std)
488
        elif classname.find("Embedding") != -1:
489
            if hasattr(m, "weight"):
490
                self._init_weight(m.weight)
491
        elif classname.find("ProjectedAdaptiveLogSoftmax") != -1:
492
            if hasattr(m, "cluster_weight") and m.cluster_weight is not None:
493
                self._init_weight(m.cluster_weight)
494
            if hasattr(m, "cluster_bias") and m.cluster_bias is not None:
495
                self._init_bias(m.cluster_bias)
496
            if hasattr(m, "out_projs"):
497
                for i in range(len(m.out_projs)):
498
                    if m.out_projs[i] is not None:
499
                        nn.init.normal_(m.out_projs[i], 0.0, self.config.proj_init_std)
500
        elif classname.find("LayerNorm") != -1:
501
            if hasattr(m, "weight"):
502
                nn.init.normal_(m.weight, 1.0, self.config.init_std)
503
            if hasattr(m, "bias") and m.bias is not None:
504
                self._init_bias(m.bias)
505
        else:
506
            if hasattr(m, "r_emb"):
507
                self._init_weight(m.r_emb)
508
            if hasattr(m, "r_w_bias"):
509
                self._init_weight(m.r_w_bias)
510
            if hasattr(m, "r_r_bias"):
511
                self._init_weight(m.r_r_bias)
512
            if hasattr(m, "r_bias"):
513
                self._init_bias(m.r_bias)
514

515
    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, layer: Optional[int] = -1):
516
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
517
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
518

519
        Arguments:
520

521
            new_num_tokens: (`optional`) int:
522
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
523
                If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embedding`` module of the model.
524
            layer: (`optional`) int:
525
                Layer of the `AdaptiveEmbedding` where the resizing should be done. Per default the last layer will be resized.
526
                Be aware that when resizing other than the last layer, you have to ensure that the new token(s) in the tokenizer are at the corresponding position.
527

528
        Return: ``torch.nn.Embedding``
529
            Pointer to the input tokens Embeddings Module of the model
530
        """
531
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
532

533
        if new_num_tokens is None:
534
            return self.get_input_embeddings()
535

536
        new_num_tokens_layer, layer = self._get_new_num_tokens_layer(new_num_tokens, layer)
537
        assert new_num_tokens_layer > 0, "The size of the new embedding layer cannot be 0 or less"
538
        model_embeds = base_model._resize_token_embeddings(new_num_tokens_layer, layer)
539

540
        # Update base model and current model config
541
        self.config.vocab_size = new_num_tokens
542
        base_model.vocab_size = new_num_tokens
543
        base_model.n_token = new_num_tokens
544

545
        new_embedding_shapes = self._get_embedding_shapes()
546
        self._resize_cutoffs(new_num_tokens, new_num_tokens_layer, new_embedding_shapes, layer)
547

548
        # Tie weights again if needed
549
        self.tie_weights()
550

551
        return model_embeds
552

553
    def _get_new_num_tokens_layer(self, new_num_tokens, layer):
554
        embeddings = self.get_input_embeddings()
555
        if layer == -1:
556
            layer = len(embeddings.emb_layers) - 1
557
        assert 0 <= layer <= len(embeddings.emb_layers) - 1
558

559
        new_num_tokens_layer = (
560
            new_num_tokens
561
            - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[:layer]])
562
            - sum([emb.weight.shape[0] for emb in embeddings.emb_layers[layer + 1 :]])
563
        )
564
        return new_num_tokens_layer, layer
565

566
    def _get_embedding_shapes(self):
567
        embeddings = self.get_input_embeddings()
568
        return [emb.weight.shape[0] for emb in embeddings.emb_layers]
569

570
    def _resize_token_embeddings(self, new_num_tokens, layer=-1):
571
        embeddings = self.get_input_embeddings()
572
        if new_num_tokens is None:
573
            return embeddings
574
        new_embeddings_layer = self._get_resized_embeddings(embeddings.emb_layers[layer], new_num_tokens)
575
        embeddings.emb_layers[layer] = new_embeddings_layer
576

577
        self.set_input_embeddings(embeddings)
578

579
        return self.get_input_embeddings()
580

581
    def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer):
582
        embeddings = self.get_input_embeddings()
583

584
        for i in range(layer, len(embeddings.cutoffs)):
585
            embeddings.cutoffs[i] = sum(new_embedding_shapes[: i + 1])
586

587
        embeddings.cutoff_ends = [0] + embeddings.cutoffs
588
        embeddings.n_token = new_num_tokens
589

590
        self.config.cutoffs = embeddings.cutoffs[:-1]
591

592
        return embeddings.cutoffs
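
# Illustrative usage sketch (not part of the original file): resizing the adaptive
# embedding. Only one cutoff bucket is resized (the last one by default), and the cutoffs
# stored in the config are updated accordingly. Loading "transfo-xl-wt103" downloads the
# pretrained checkpoint.
#
#     model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
#     embeddings = model.resize_token_embeddings(new_num_tokens=267745)   # grow the last bucket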
593

594

595
@dataclass
596
class TransfoXLModelOutput(ModelOutput):
597
    """
598
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
599

600
    Args:
601
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
602
            Sequence of hidden-states at the output of the last layer of the model.
603
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
604
            Contains pre-computed hidden-states (key and values in the attention blocks).
605
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
606
            should not be passed as input ids as they have already been computed.
607
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
608
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
609
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
610

611
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
612
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
613
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
614
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
615

616
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
617
            heads.
618
    """
619

620
    last_hidden_state: torch.FloatTensor
621
    mems: List[torch.FloatTensor] = None
622
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
623
    attentions: Optional[Tuple[torch.FloatTensor]] = None
624

625

626
@dataclass
627
class TransfoXLLMHeadModelOutput(ModelOutput):
628
    """
629
    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
630

631
    Args:
632
        losses (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length-1)`, `optional`, returned when ``labels`` is provided):
633
            Language modeling losses (not reduced).
634
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
635
            Prediction scores of the language modeling head (scores for each vocabulary token after SoftMax).
636
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
637
            Contains pre-computed hidden-states (key and values in the attention blocks).
638
            Can be used (see `mems` input) to speed up sequential decoding. The token ids which have their past given to this model
639
            should not be passed as input ids as they have already been computed.
640
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
641
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
642
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
643

644
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
645
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
646
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
647
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
648

649
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
650
            heads.
651
    """
652

653
    losses: Optional[torch.FloatTensor] = None
654
    prediction_scores: torch.FloatTensor = None
655
    mems: List[torch.FloatTensor] = None
656
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
657
    attentions: Optional[Tuple[torch.FloatTensor]] = None
658

659

660
TRANSFO_XL_START_DOCSTRING = r"""
661

662
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
663
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
664
    usage and behavior.
665

666
    Parameters:
667
        config (:class:`~transformers.TransfoXLConfig`): Model configuration class with all the parameters of the model.
668
            Initializing with a config file does not load the weights associated with the model, only the configuration.
669
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
670
"""
671

672
TRANSFO_XL_INPUTS_DOCSTRING = r"""
673
    Args:
674
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
675
            Indices of input sequence tokens in the vocabulary.
676

677
            Indices can be obtained using :class:`transformers.TransfoXLTokenizer`.
678
            See :func:`transformers.PreTrainedTokenizer.encode` and
679
            :func:`transformers.PreTrainedTokenizer.__call__` for details.
680

681
            `What are input IDs? <../glossary.html#input-ids>`__
682
        mems (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
683
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
684
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
685
            given to this model should not be passed as input ids as they have already been computed.
686
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
687
            Mask to nullify selected heads of the self-attention modules.
688
            Mask values selected in ``[0, 1]``:
689
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
690
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
691
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
692
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
693
            than the model's internal embedding lookup matrix.
694
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
695
            If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
696
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
697
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
698
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
699
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
700
            plain tuple.
701
"""
702

703

704
@add_start_docstrings(
705
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
706
    TRANSFO_XL_START_DOCSTRING,
707
)
708
class TransfoXLModel(TransfoXLPreTrainedModel):
709
    def __init__(self, config):
710
        super().__init__(config)
711

712
        self.n_token = config.vocab_size
713

714
        self.d_embed = config.d_embed
715
        self.d_model = config.d_model
716
        self.n_head = config.n_head
717
        self.d_head = config.d_head
718

719
        self.word_emb = AdaptiveEmbedding(
720
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
721
        )
722

723
        self.drop = nn.Dropout(config.dropout)
724

725
        self.n_layer = config.n_layer
726

727
        self.tgt_len = config.tgt_len
728
        self.mem_len = config.mem_len
729
        self.ext_len = config.ext_len
730
        self.max_klen = config.tgt_len + config.ext_len + config.mem_len
731

732
        self.attn_type = config.attn_type
733

734
        if not config.untie_r:
735
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
736
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
737

738
        self.layers = nn.ModuleList()
739
        if config.attn_type == 0:  # the default attention
740
            for i in range(config.n_layer):
741
                self.layers.append(
742
                    RelPartialLearnableDecoderLayer(
743
                        config.n_head,
744
                        config.d_model,
745
                        config.d_head,
746
                        config.d_inner,
747
                        config.dropout,
748
                        tgt_len=config.tgt_len,
749
                        ext_len=config.ext_len,
750
                        mem_len=config.mem_len,
751
                        dropatt=config.dropatt,
752
                        pre_lnorm=config.pre_lnorm,
753
                        r_w_bias=None if config.untie_r else self.r_w_bias,
754
                        r_r_bias=None if config.untie_r else self.r_r_bias,
755
                        layer_norm_epsilon=config.layer_norm_epsilon,
756
                    )
757
                )
758
        else:  # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints
759
            raise NotImplementedError  # Removed them to avoid maintaining dead code
760

761
        self.same_length = config.same_length
762
        self.clamp_len = config.clamp_len
763

764
        if self.attn_type == 0:  # default attention
765
            self.pos_emb = PositionalEmbedding(self.d_model)
766
        else:  # learnable embeddings and absolute embeddings
767
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
768

769
        self.init_weights()
770

771
    def get_input_embeddings(self):
772
        return self.word_emb
773

774
    def set_input_embeddings(self, new_embeddings):
775
        self.word_emb = new_embeddings
776

777
    def backward_compatible(self):
778
        self.sample_softmax = -1
779

780
    def reset_length(self, tgt_len, ext_len, mem_len):
781
        self.tgt_len = tgt_len
782
        self.mem_len = mem_len
783
        self.ext_len = ext_len
784

785
    def _prune_heads(self, heads):
786
        logger.info("Head pruning is not implemented for Transformer-XL model")
787
        pass
788

789
    def init_mems(self, bsz):
790
        if self.mem_len > 0:
791
            mems = []
792
            param = next(self.parameters())
793
            for i in range(self.n_layer):
794
                empty = torch.zeros(self.mem_len, bsz, self.config.d_model, dtype=param.dtype, device=param.device)
795
                mems.append(empty)
796

797
            return mems
798
        else:
799
            return None
800

801
    def _update_mems(self, hids, mems, mlen, qlen):
802
        # does not deal with None
803
        if mems is None:
804
            return None
805

806
        # mems is not None
807
        assert len(hids) == len(mems), "len(hids) != len(mems)"
808

809
        # There are `mlen + qlen` steps that can be cached into mems
810
        # For the next step, the last `ext_len` of the `qlen` tokens
811
        # will be used as the extended context. Hence, we only cache
812
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
813
        # to `mlen + qlen - self.ext_len`.
814
        with torch.no_grad():
815
            new_mems = []
816
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
817
            beg_idx = max(0, end_idx - self.mem_len)
818
            for i in range(len(hids)):
819

820
                cat = torch.cat([mems[i], hids[i]], dim=0)
821
                new_mems.append(cat[beg_idx:end_idx].detach())
822

823
        return new_mems
824

825
    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
826
    @add_code_sample_docstrings(
827
        tokenizer_class=_TOKENIZER_FOR_DOC,
828
        checkpoint="transfo-xl-wt103",
829
        output_type=TransfoXLModelOutput,
830
        config_class=_CONFIG_FOR_DOC,
831
    )
832
    def forward(
833
        self,
834
        input_ids=None,
835
        mems=None,
836
        head_mask=None,
837
        inputs_embeds=None,
838
        output_attentions=None,
839
        output_hidden_states=None,
840
        return_dict=None,
841
    ):
842
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
843
        output_hidden_states = (
844
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
845
        )
846
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
847

848
        # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library
849
        # so we transpose here from shape [bsz, len] to shape [len, bsz]
850
        if input_ids is not None and inputs_embeds is not None:
851
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
852
        elif input_ids is not None:
853
            input_ids = input_ids.transpose(0, 1).contiguous()
854
            qlen, bsz = input_ids.size()
855
        elif inputs_embeds is not None:
856
            inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
857
            qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1]
858
        else:
859
            raise ValueError("You have to specify either input_ids or inputs_embeds")
860

861
        if mems is None:
862
            mems = self.init_mems(bsz)
863

864
        # Prepare head mask if needed
865
        # 1.0 in head_mask indicate we keep the head
866
        # attention_probs has shape bsz x n_heads x N x N
867
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
868
        # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
869
        if head_mask is not None:
870
            if head_mask.dim() == 1:
871
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
872
                head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
873
            elif head_mask.dim() == 2:
874
                head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
875
            head_mask = head_mask.to(
876
                dtype=next(self.parameters()).dtype
877
            )  # switch to float if needed + fp16 compatibility
878
        else:
879
            head_mask = [None] * self.n_layer
880

881
        if inputs_embeds is not None:
882
            word_emb = inputs_embeds
883
        else:
884
            word_emb = self.word_emb(input_ids)
885

886
        mlen = mems[0].size(0) if mems is not None else 0
887
        klen = mlen + qlen
888
        if self.same_length:
889
            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
890
            mask_len = klen - self.mem_len
891
            if mask_len > 0:
892
                mask_shift_len = qlen - mask_len
893
            else:
894
                mask_shift_len = qlen
895
            dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
896
        else:
897
            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
898
                :, :, None
899
            ]
900

901
        hids = []
902
        attentions = [] if output_attentions else None
903
        if self.attn_type == 0:  # default
904
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype)
905
            if self.clamp_len > 0:
906
                pos_seq.clamp_(max=self.clamp_len)
907
            pos_emb = self.pos_emb(pos_seq)
908

909
            core_out = self.drop(word_emb)
910
            pos_emb = self.drop(pos_emb)
911

912
            for i, layer in enumerate(self.layers):
913
                hids.append(core_out)
914
                mems_i = None if mems is None else mems[i]
915
                layer_outputs = layer(
916
                    core_out,
917
                    pos_emb,
918
                    dec_attn_mask=dec_attn_mask,
919
                    mems=mems_i,
920
                    head_mask=head_mask[i],
921
                    output_attentions=output_attentions,
922
                )
923
                core_out = layer_outputs[0]
924
                if output_attentions:
925
                    attentions.append(layer_outputs[1])
926
        else:  # learnable embeddings and absolute embeddings
927
            raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
928

929
        core_out = self.drop(core_out)
930

931
        new_mems = self._update_mems(hids, mems, mlen, qlen)
932

933
        if output_hidden_states:
934
            # Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
935
            hids.append(core_out)
936
            hids = tuple(t.transpose(0, 1).contiguous() for t in hids)
937
        else:
938
            hids = None
939
        if output_attentions:
940
            # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
941
            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
942
        # We transpose back here to shape [bsz, len, hidden_dim]
943
        core_out = core_out.transpose(0, 1).contiguous()
944

945
        if not return_dict:
946
            return tuple(v for v in [core_out, new_mems, hids, attentions] if v is not None)
947

948
        return TransfoXLModelOutput(
949
            last_hidden_state=core_out, mems=new_mems, hidden_states=hids, attentions=attentions,
950
        )
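
# Illustrative usage sketch (not part of the original file): the recurrence memory
# returned as `mems` can be fed back in for the next segment so earlier tokens do not
# have to be re-encoded. `TransfoXLTokenizer` would be imported from the library
# alongside this model; loading "transfo-xl-wt103" downloads the pretrained checkpoint.
#
#     tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
#     model = TransfoXLModel.from_pretrained("transfo-xl-wt103")
#     seg1 = tokenizer("The quick brown fox", return_tensors="pt")["input_ids"]
#     out1 = model(seg1, return_dict=True)
#     seg2 = tokenizer("jumps over the lazy dog", return_tensors="pt")["input_ids"]
#     out2 = model(seg2, mems=out1.mems, return_dict=True)   # reuses cached hidden states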
951

952

953
@add_start_docstrings(
954
    """The Transformer-XL Model with a language modeling head on top
955
    (adaptive softmax with weights tied to the adaptive input embeddings)""",
956
    TRANSFO_XL_START_DOCSTRING,
957
)
958
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
959
    def __init__(self, config):
960
        super().__init__(config)
961
        self.transformer = TransfoXLModel(config)
962
        self.sample_softmax = config.sample_softmax
963

964
        assert (
965
            self.sample_softmax <= 0
966
        ), "Sampling from the softmax is not implemented yet. Please look at issue: #3310: https://github.com/huggingface/transformers/issues/3310"
967

968
        self.crit = ProjectedAdaptiveLogSoftmax(
969
            config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val
970
        )
971

972
        self.init_weights()
973

974
    def tie_weights(self):
975
        """
976
        Run this to be sure output and input (adaptive) softmax weights are tied
977
        """
978

979
        if self.config.tie_weight:
980
            for i in range(len(self.crit.out_layers)):
981
                self._tie_or_clone_weights(self.crit.out_layers[i], self.transformer.word_emb.emb_layers[i])
982
        if self.config.tie_projs:
983
            for i, tie_proj in enumerate(self.config.tie_projs):
984
                if tie_proj and self.config.div_val == 1 and self.config.d_model != self.config.d_embed:
985
                    if self.config.torchscript:
986
                        self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[0].clone())
987
                    else:
988
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[0]
989
                elif tie_proj and self.config.div_val != 1:
990
                    if self.config.torchscript:
991
                        self.crit.out_projs[i] = nn.Parameter(self.transformer.word_emb.emb_projs[i].clone())
992
                    else:
993
                        self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
994

995
    def reset_length(self, tgt_len, ext_len, mem_len):
996
        self.transformer.reset_length(tgt_len, ext_len, mem_len)
997

998
    def init_mems(self, bsz):
999
        return self.transformer.init_mems(bsz)
1000

1001
    @add_start_docstrings_to_callable(TRANSFO_XL_INPUTS_DOCSTRING)
1002
    @add_code_sample_docstrings(
1003
        tokenizer_class=_TOKENIZER_FOR_DOC,
1004
        checkpoint="transfo-xl-wt103",
1005
        output_type=TransfoXLLMHeadModelOutput,
1006
        config_class=_CONFIG_FOR_DOC,
1007
    )
1008
    def forward(
1009
        self,
1010
        input_ids=None,
1011
        mems=None,
1012
        head_mask=None,
1013
        inputs_embeds=None,
1014
        labels=None,
1015
        output_attentions=None,
1016
        output_hidden_states=None,
1017
        return_dict=None,
1018
    ):
1019
        r"""
1020
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1021
            Labels for language modeling.
1022
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
1023
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``
1024
            All labels set to ``-100`` are ignored (masked), the loss is only
1025
            computed for labels in ``[0, ..., config.vocab_size]``
1026
        """
1027
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1028
        if input_ids is not None:
1029
            bsz, tgt_len = input_ids.size(0), input_ids.size(1)
1030
        elif inputs_embeds is not None:
1031
            bsz, tgt_len = inputs_embeds.size(0), inputs_embeds.size(1)
1032
        else:
1033
            raise ValueError("You have to specify either input_ids or inputs_embeds")
1034

1035
        transformer_outputs = self.transformer(
1036
            input_ids,
1037
            mems=mems,
1038
            head_mask=head_mask,
1039
            inputs_embeds=inputs_embeds,
1040
            output_attentions=output_attentions,
1041
            output_hidden_states=output_hidden_states,
1042
            return_dict=return_dict,
1043
        )
1044

1045
        last_hidden = transformer_outputs[0]
1046
        pred_hid = last_hidden[:, -tgt_len:]
1047

1048
        softmax_output = self.crit(pred_hid, labels)
1049
        prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
1050
        loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
1051

1052
        if not return_dict:
1053
            output = (prediction_scores,) + transformer_outputs[1:]
1054
            return ((loss,) + output) if loss is not None else output
1055

1056
        return TransfoXLLMHeadModelOutput(
1057
            losses=loss,
1058
            prediction_scores=prediction_scores,
1059
            mems=transformer_outputs.mems,
1060
            hidden_states=transformer_outputs.hidden_states,
1061
            attentions=transformer_outputs.attentions,
1062
        )
1063

1064
    def get_output_embeddings(self):
1065
        """ Double-check if you are using adaptive softmax.
1066
        """
1067
        if self.sample_softmax > 0:
1068
            return self.out_layer
1069
        else:
1070
            return self.crit.out_layers[-1]
1071

1072
    def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
1073
        inputs = {}
1074

1075
        # if past is defined in model kwargs then use it for faster decoding
1076
        if past:
1077
            inputs["mems"] = past
1078
            inputs["input_ids"] = input_ids[:, -1].unsqueeze(-1)
1079
        else:
1080
            inputs["input_ids"] = input_ids
1081

1082
        return inputs
1083

1084
    def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, layer):
1085
        new_cutoffs = super()._resize_cutoffs(new_num_tokens, new_emb_size, new_embedding_shapes, layer)
1086

1087
        self.crit.cutoffs = new_cutoffs
1088
        self.crit.cutoff_ends = [0] + new_cutoffs
1089
        self.crit.n_token = new_num_tokens
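
# Illustrative usage sketch (not part of the original file): when `labels` are provided
# the LM head returns per-token losses of shape [batch, seq_len - 1] (labels are shifted
# inside the model), so a scalar training loss is simply their mean; prediction scores
# are only returned when `labels` is omitted.
#
#     tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
#     lm_model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
#     input_ids = tokenizer("The quick brown fox jumps", return_tensors="pt")["input_ids"]
#     outputs = lm_model(input_ids, labels=input_ids, return_dict=True)
#     loss = outputs.losses.mean()                                        # scalar NLL
#     scores = lm_model(input_ids, return_dict=True).prediction_scores    # [1, seq_len, vocab] log-probs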
1090
