CSS-LM / modeling_transfo_xl_utilities.py

# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch Transformer XL model.
    Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F


# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
# CUDA_MINOR = int(torch.version.cuda.split('.')[1])


class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, keep_order=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        if div_val == 1:
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

                self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

    def forward(self, hidden, labels=None, keep_order=False):
        """
            Params:
                hidden :: [len*bsz x d_proj]
                labels :: [len*bsz]
            Return:
                if labels is None:
                    out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
                else:
                    out :: [(len-1)*bsz] negative log likelihood of the shifted labels
            We could replace this implementation with the native PyTorch one
            if it had an option to set the bias on all clusters:
            https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
            A small usage sketch is appended at the bottom of this file.
        """

        if labels is not None:
            # Shift so that tokens < n predict n
            hidden = hidden[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            hidden = hidden.view(-1, hidden.size(-1))
            labels = labels.view(-1)
            if hidden.size(0) != labels.size(0):
                raise RuntimeError("Input and labels should have the same size in the batch dimension.")
        else:
            hidden = hidden.view(-1, hidden.size(-1))

        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
            if labels is not None:
                out = -F.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
            else:
                out = F.log_softmax(logit, dim=-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

            if labels is None:
                out = hidden.new_empty((head_logit.size(0), self.n_token))
            else:
                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                if labels is not None:
                    mask_i = (labels >= l_idx) & (labels < r_idx)
                    # squeeze only the last dim so the index stays 1-D even when
                    # a single label falls into this cluster
                    indices_i = mask_i.nonzero().squeeze(-1)

                    if indices_i.numel() == 0:
                        continue

                    target_i = labels.index_select(0, indices_i) - l_idx
                    head_logprob_i = head_logprob.index_select(0, indices_i)
                    hidden_i = hidden.index_select(0, indices_i)
                else:
                    hidden_i = hidden

                if i == 0:
                    if labels is not None:
                        logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                    else:
                        out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
                    if labels is not None:
                        logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather(
                            1, target_i[:, None]
                        ).squeeze(1)
                    else:
                        logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                        out[:, l_idx:r_idx] = logprob_i

                if labels is not None:
                    if (hasattr(self, "keep_order") and self.keep_order) or keep_order:
                        out.index_copy_(0, indices_i, -logprob_i)
                    else:
                        out[offset : offset + logprob_i.size(0)].copy_(-logprob_i)
                    offset += logprob_i.size(0)

        return out

    def log_prob(self, hidden):
        r""" Computes log probabilities for all :math:`n\_classes`
        From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
        Args:
            hidden (Tensor): a minibatch of examples
        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c < n\_classes`, where :math:`n\_classes` is the
            ``n_token`` parameter passed to the ``ProjectedAdaptiveLogSoftmax`` constructor.
        Shape:
            - Input: :math:`(N, in\_features)`
            - Output: :math:`(N, n\_classes)`
        """
        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
            return F.log_softmax(logit, dim=-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)

            out = hidden.new_empty((head_logit.size(0), self.n_token))
            head_logprob = F.log_softmax(head_logit, dim=1)

            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]

                if i == 0:
                    out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                    # Same cluster indexing as in forward(): tail cluster i sits at
                    # column cutoffs[0] + i - 1 of the head distribution.
                    cluster_prob_idx = self.cutoffs[0] + i - 1
                    logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                    out[:, start_idx:stop_idx] = logprob_i

            return out
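

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The vocabulary size,
# cutoffs and batch shapes below are illustrative assumptions chosen so that
# both the head shortlist and the two tail clusters are exercised; div_val=1
# with d_proj == d_embed keeps all output projections disabled (proj is None).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    n_token, d_embed, d_proj = 1000, 32, 32   # toy vocabulary and hidden sizes
    cutoffs = [100, 500]                      # shortlist of 100 tokens, two tail clusters
    crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs, div_val=1)

    bsz, tgt_len = 4, 8
    # batch-first hidden states: the shift in forward() drops the last time step
    hidden = torch.randn(bsz, tgt_len, d_proj)
    labels = torch.randint(0, n_token, (bsz, tgt_len))

    # With labels: per-position negative log likelihood of the shifted targets.
    nll = crit(hidden, labels)
    print("nll:", nll.shape)                  # [bsz * (tgt_len - 1)]

    # Without labels: full log-probability distribution over the vocabulary.
    logprobs = crit(hidden)
    print("logprobs:", logprobs.shape)        # [bsz * tgt_len, n_token]
    print("prob mass of first row:", logprobs[0].exp().sum().item())  # ~1.0

    # log_prob expects a flat [N, d_proj] batch and returns [N, n_token].
    lp = crit.log_prob(hidden.view(-1, d_proj))
    print("log_prob:", lp.shape)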