# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" A TF 2.0 Adaptive Softmax for Transformer XL model.
"""


import tensorflow as tf

from .modeling_tf_utils import shape_list


class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
    """Adaptive softmax (Grave et al., 2017) output layer for the TF 2.0
    Transformer-XL model. Token ids are split by `cutoffs` into a frequent
    "head" shortlist and progressively rarer tail clusters, so the full
    softmax over `vocab_size` logits never has to be materialized at once.
    """

    def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [vocab_size]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters
        self.keep_order = keep_order

        self.out_layers = []
        self.out_projs = []

    def build(self, input_shape):
        if self.n_clusters > 0:
            # One extra logit per tail cluster, appended to the head softmax.
            self.cluster_weight = self.add_weight(
                shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight"
            )
            self.cluster_bias = self.add_weight(
                shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias"
            )

        if self.div_val == 1:
            # All clusters share one full-vocabulary weight matrix, which is
            # sliced per cluster at call time.
            for i in range(len(self.cutoffs)):
                if self.d_proj != self.d_embed:
                    weight = self.add_weight(
                        shape=(self.d_embed, self.d_proj),
                        initializer="zeros",
                        trainable=True,
                        name="out_projs_._{}".format(i),
                    )
                    self.out_projs.append(weight)
                else:
                    self.out_projs.append(None)
                weight = self.add_weight(
                    shape=(self.vocab_size, self.d_embed),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._weight".format(i),
                )
                bias = self.add_weight(
                    shape=(self.vocab_size,),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._bias".format(i),
                )
                self.out_layers.append((weight, bias))
        else:
            # The embedding width shrinks by a factor of div_val per cluster,
            # so each cluster gets its own smaller weight matrix plus a
            # projection from d_proj down to that width.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = self.d_embed // (self.div_val ** i)

                weight = self.add_weight(
                    shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name="out_projs_._{}".format(i)
                )
                self.out_projs.append(weight)
                weight = self.add_weight(
                    shape=(r_idx - l_idx, d_emb_i),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._weight".format(i),
                )
                bias = self.add_weight(
                    shape=(r_idx - l_idx,),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._bias".format(i),
                )
                self.out_layers.append((weight, bias))
        super().build(input_shape)

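    # Illustrative shapes created by build() for an assumed config
    # (vocab_size=1000, d_embed=64, d_proj=64, cutoffs=[100, 500], div_val=2;
    # hypothetical values, not a released checkpoint):
    #   cluster_weight (2, 64), cluster_bias (2,)
    #   i=0: out_projs[0] (64, 64), out_layers[0] = ((100, 64), (100,))
    #   i=1: out_projs[1] (32, 64), out_layers[1] = ((400, 32), (400,))
    #   i=2: out_projs[2] (16, 64), out_layers[2] = ((500, 16), (500,))
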
    @staticmethod
    def _logit(x, W, b, proj=None):
        # x: [time, batch, d_proj]; the optional proj maps it down to the
        # cluster's embedding width before the output matmul.
        y = x
        if proj is not None:
            y = tf.einsum("ibd,ed->ibe", y, proj)
        return tf.einsum("ibd,nd->ibn", y, W) + b

    @staticmethod
    def _gather_logprob(logprob, target):
        # Pick logprob[j, target[j]] for every row j.
        lp_size = shape_list(logprob)
        r = tf.range(lp_size[0])
        idx = tf.stack([r, target], 1)
        return tf.gather_nd(logprob, idx)
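
    # Worked example for _gather_logprob (illustrative): with
    #   logprob = [[-0.1, -2.3], [-1.2, -0.4]] and target = [1, 0],
    # idx is [[0, 1], [1, 0]], so the result is [-2.3, -1.2]: the
    # log-probability each row assigns to its own target index.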

    def call(self, inputs, return_mean=True, training=False):
        hidden, target = inputs
        head_logprob = 0
        if self.n_clusters == 0:
            # Plain softmax: no clusters, a single output projection.
            output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
            if target is not None:
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
            out = tf.nn.log_softmax(output, axis=-1)
        else:
            hidden_sizes = shape_list(hidden)
            out = []
            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                if target is not None:
                    mask = (target >= l_idx) & (target < r_idx)
                    mask_idx = tf.where(mask)
                    cur_target = tf.boolean_mask(target, mask) - l_idx

                if self.div_val == 1:
                    cur_W = self.out_layers[0][0][l_idx:r_idx]
                    cur_b = self.out_layers[0][1][l_idx:r_idx]
                else:
                    cur_W = self.out_layers[i][0]
                    cur_b = self.out_layers[i][1]

                if i == 0:
                    # Head cluster: shortlist logits plus one logit per tail cluster.
                    cur_W = tf.concat([cur_W, self.cluster_weight], 0)
                    cur_b = tf.concat([cur_b, self.cluster_bias], 0)

                    head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
                    head_logprob = tf.nn.log_softmax(head_logit)
                    out.append(head_logprob[..., : self.cutoffs[0]])
                    if target is not None:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
                else:
                    tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
                    tail_logprob = tf.nn.log_softmax(tail_logit)
                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
                    # log p(token) = log p(cluster | head) + log p(token | cluster)
                    logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
                    out.append(logprob_i)
                    if target is not None:
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
                        cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
                        cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                        cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                if target is not None:
                    # Scatter each cluster's negative log-likelihoods back to
                    # the [time, batch] positions they came from.
                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
            out = tf.concat(out, axis=-1)

        if target is not None:
            if return_mean:
                loss = tf.reduce_mean(loss)
            # Add the training-time loss value to the layer using `self.add_loss()`.
            self.add_loss(loss)

            # Log the loss as a metric (we could log arbitrary metrics,
            # including different metrics for training and inference).
            self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "")

        return out
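
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). All values are
# assumptions; inputs follow the [time, batch, ...] layout expected by the
# einsum equations above. With a target the layer also registers the NLL via
# add_loss()/add_metric():
#
#   layer = TFAdaptiveSoftmaxMask(vocab_size=1000, d_embed=32, d_proj=32,
#                                 cutoffs=[100, 500])
#   hidden = tf.random.normal([10, 4, 32])    # [seq_len, batch, d_proj]
#   target = tf.random.uniform([10, 4], maxval=1000, dtype=tf.int32)
#   logprob = layer([hidden, target])         # [seq_len, batch, vocab_size]
#   nll = layer.losses[0]                     # mean NLL added by add_loss()
# ---------------------------------------------------------------------------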