# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Utilities for PyTorch Transformer XL model.
Directly adapted from https://github.com/kimiyoung/transformer-xl.
"""


import torch
import torch.nn as nn
import torch.nn.functional as F

# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
# CUDA_MINOR = int(torch.version.cuda.split('.')[1])

class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, keep_order=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        # The head predicts the shortlist tokens plus one entry per tail cluster.
        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        if div_val == 1:
            # A single output layer over the full vocabulary; per-cluster slices
            # of its weight/bias are taken in forward().
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed)))
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            # One output layer per cluster, with the embedding size shrinking by
            # a factor of div_val for each successive (rarer) cluster.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)))

                self.out_layers.append(nn.Linear(d_emb_i, r_idx - l_idx))

        self.keep_order = keep_order
    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit
    def forward(self, hidden, labels=None, keep_order=False):
        """
        Params:
            hidden :: [len*bsz x d_proj]
            labels :: [len*bsz]
        Return:
            if labels is None:
                out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary
            else:
                out :: [(len-1)*bsz] Negative log likelihood
        We could replace this implementation by the native PyTorch one
        if it had an option to set the bias on all clusters:
        https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138
        """

        if labels is not None:
            # Shift so that tokens < n predict n
            hidden = hidden[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            hidden = hidden.view(-1, hidden.size(-1))
            labels = labels.view(-1)
            if hidden.size(0) != labels.size(0):
                raise RuntimeError("Input and labels should have the same size in the batch dimension.")
        else:
            hidden = hidden.view(-1, hidden.size(-1))
        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
            if labels is not None:
                out = -F.log_softmax(logit, dim=-1).gather(1, labels.unsqueeze(1)).squeeze(1)
            else:
                out = F.log_softmax(logit, dim=-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

            if labels is None:
                out = hidden.new_empty((head_logit.size(0), self.n_token))
            else:
                out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                if labels is not None:
                    mask_i = (labels >= l_idx) & (labels < r_idx)
                    indices_i = mask_i.nonzero().squeeze()

                    if indices_i.numel() == 0:
                        continue

                    target_i = labels.index_select(0, indices_i) - l_idx
                    head_logprob_i = head_logprob.index_select(0, indices_i)
                    hidden_i = hidden.index_select(0, indices_i)
                else:
                    hidden_i = hidden

                if i == 0:
                    if labels is not None:
                        logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
                    else:
                        out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
                    if labels is not None:
                        logprob_i = head_logprob_i[:, cluster_prob_idx] + tail_logprob_i.gather(
                            1, target_i[:, None]
                        ).squeeze(1)
                    else:
                        logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                        out[:, l_idx:r_idx] = logprob_i

                if labels is not None:
                    if (hasattr(self, "keep_order") and self.keep_order) or keep_order:
                        out.index_copy_(0, indices_i, -logprob_i)
                    else:
                        out[offset : offset + logprob_i.size(0)].copy_(-logprob_i)
                    offset += logprob_i.size(0)

        return out

    def log_prob(self, hidden):
        r""" Computes log probabilities for all :math:`n\_classes`
        From: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py
        Args:
            hidden (Tensor): a minibatch of examples
        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c < n\_classes`, where :math:`n\_classes` is the
            ``n_token`` vocabulary size passed to the constructor.
        Shape:
            - Input: :math:`(N, in\_features)`
            - Output: :math:`(N, n\_classes)`
        """
        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight, self.out_layers[0].bias, self.out_projs[0])
            return F.log_softmax(logit, dim=-1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat([weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat([bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)

            out = hidden.new_empty((head_logit.size(0), self.n_token))
            head_logprob = F.log_softmax(head_logit, dim=1)

            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1]

                if i == 0:
                    out[:, : self.cutoffs[0]] = head_logprob[:, : self.cutoffs[0]]
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                    # Combine the head log-prob of cluster i with the tail
                    # log-probs, mirroring the indexing used in forward().
                    cluster_prob_idx = self.cutoffs[0] + i - 1
                    logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i
                    out[:, start_idx:stop_idx] = logprob_i

            return out
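

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch with made-up sizes, not part
    # of the original Transformer-XL / CSS-LM code. Uses d_proj == d_embed and
    # div_val == 1 so no uninitialized projection parameters are created.
    torch.manual_seed(0)
    n_token, d_embed, d_proj = 1000, 32, 32
    crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs=[100, 500], div_val=1)

    # Scoring: full log-probabilities over the vocabulary.
    hidden = torch.randn(32, d_proj)                      # (len*bsz, d_proj)
    logprob = crit.log_prob(hidden)                       # (32, n_token)
    assert logprob.shape == (32, n_token)
    # Each row should be a valid log-distribution over the full vocabulary.
    assert torch.allclose(logprob.exp().sum(dim=1), torch.ones(32), atol=1e-4)

    # Training: labels trigger the internal one-step shift along the sequence
    # dimension, so pass unflattened (bsz, len, d_proj) / (bsz, len) tensors.
    # Deterministic labels keep every visited cluster non-trivially populated.
    bsz, seq_len = 8, 5
    hidden3 = torch.randn(bsz, seq_len, d_proj)
    labels = (torch.arange(bsz * seq_len).view(bsz, seq_len) * 7) % n_token
    nll = crit(hidden3, labels)                           # per-token NLL
    assert nll.shape == (bsz * (seq_len - 1),)
    print("mean NLL:", nll.mean().item())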