CSS-LM · modeling_tf_transfo_xl_utilities.py
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A TF 2.0 adaptive softmax for the Transformer-XL model."""


import tensorflow as tf

from .modeling_tf_utils import shape_list
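
# Note (added commentary, not in the original file): an adaptive softmax
# splits the vocabulary by frequency into a small "head" shortlist
# (indices below cutoffs[0]) and one "tail" cluster per remaining cutoff
# interval. The head softmax scores the shortlist tokens plus one entry per
# tail cluster; a tail token's log-probability is the log-probability of its
# cluster under the head softmax plus its log-probability within the
# cluster's own, smaller softmax. See Grave et al., "Efficient softmax
# approximation for GPUs" (2016).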


class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
    def __init__(self, vocab_size, d_embed, d_proj, cutoffs, div_val=1, keep_order=False, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [vocab_size]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters
        self.keep_order = keep_order

        self.out_layers = []
        self.out_projs = []
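
        # Added commentary (not in the original file). Worked example of the
        # bookkeeping above for vocab_size=12, cutoffs=[4, 8]:
        #   self.cutoffs     = [4, 8, 12]
        #   self.cutoff_ends = [0, 4, 8, 12]  # cluster i spans [cutoff_ends[i], cutoff_ends[i+1])
        #   self.n_clusters  = 2
        #   self.head_size   = 4 + 2 = 6      # shortlist plus one logit per tail cluster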

    def build(self, input_shape):
        if self.n_clusters > 0:
            self.cluster_weight = self.add_weight(
                shape=(self.n_clusters, self.d_embed), initializer="zeros", trainable=True, name="cluster_weight"
            )
            self.cluster_bias = self.add_weight(
                shape=(self.n_clusters,), initializer="zeros", trainable=True, name="cluster_bias"
            )

        if self.div_val == 1:
            for i in range(len(self.cutoffs)):
                if self.d_proj != self.d_embed:
                    weight = self.add_weight(
                        shape=(self.d_embed, self.d_proj),
                        initializer="zeros",
                        trainable=True,
                        name="out_projs_._{}".format(i),
                    )
                    self.out_projs.append(weight)
                else:
                    self.out_projs.append(None)
                weight = self.add_weight(
                    shape=(self.vocab_size, self.d_embed),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._weight".format(i),
                )
                bias = self.add_weight(
                    shape=(self.vocab_size,),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._bias".format(i),
                )
                self.out_layers.append((weight, bias))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                d_emb_i = self.d_embed // (self.div_val ** i)

                weight = self.add_weight(
                    shape=(d_emb_i, self.d_proj), initializer="zeros", trainable=True, name="out_projs_._{}".format(i)
                )
                self.out_projs.append(weight)
                weight = self.add_weight(
                    shape=(r_idx - l_idx, d_emb_i),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._weight".format(i),
                )
                bias = self.add_weight(
                    shape=(r_idx - l_idx,),
                    initializer="zeros",
                    trainable=True,
                    name="out_layers_._{}_._bias".format(i),
                )
                self.out_layers.append((weight, bias))
        super().build(input_shape)
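
    # Added commentary (not in the original file): with div_val == 1 all
    # clusters share the full embedding width, so `build` registers one full
    # (vocab_size, d_embed) matrix per cutoff (only out_layers[0] is actually
    # used; `call` slices it per cluster). With div_val > 1 the width shrinks
    # geometrically (d_embed // div_val**i), so rarer clusters use fewer
    # parameters, each with its own (cluster_size, d_emb_i) weight plus a
    # projection back to d_proj.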

    @staticmethod
    def _logit(x, W, b, proj=None):
        # x: (seq_len, batch, d) hidden states; proj (optional): (d_embed, d)
        # projection applied first; W: (n_tokens, d_embed), b: (n_tokens,).
        # Returns (seq_len, batch, n_tokens) logits.
        y = x
        if proj is not None:
            y = tf.einsum("ibd,ed->ibe", y, proj)
        return tf.einsum("ibd,nd->ibn", y, W) + b

    @staticmethod
    def _gather_logprob(logprob, target):
        # Picks logprob[j, target[j]] for every row j: the log-probability
        # each (flattened) position assigns to its own target token.
        lp_size = shape_list(logprob)
        # Match target's dtype so tf.stack does not fail on int32/int64 mixes.
        r = tf.range(lp_size[0], dtype=target.dtype)
        idx = tf.stack([r, target], 1)
        return tf.gather_nd(logprob, idx)

117
    def call(self, inputs, return_mean=True, training=False):
118
        hidden, target = inputs
119
        head_logprob = 0
120
        if self.n_clusters == 0:
121
            output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
122
            if target is not None:
123
                loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output)
124
            out = tf.nn.log_softmax(output, axis=-1)
125
        else:
126
            hidden_sizes = shape_list(hidden)
127
            out = []
128
            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
129
            for i in range(len(self.cutoffs)):
130
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
131
                if target is not None:
132
                    mask = (target >= l_idx) & (target < r_idx)
133
                    mask_idx = tf.where(mask)
134
                    cur_target = tf.boolean_mask(target, mask) - l_idx
135

136
                if self.div_val == 1:
137
                    cur_W = self.out_layers[0][0][l_idx:r_idx]
138
                    cur_b = self.out_layers[0][1][l_idx:r_idx]
139
                else:
140
                    cur_W = self.out_layers[i][0]
141
                    cur_b = self.out_layers[i][1]
142

143
                if i == 0:
144
                    cur_W = tf.concat([cur_W, self.cluster_weight], 0)
145
                    cur_b = tf.concat([cur_b, self.cluster_bias], 0)
146

147
                    head_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[0])
148
                    head_logprob = tf.nn.log_softmax(head_logit)
149
                    out.append(head_logprob[..., : self.cutoffs[0]])
150
                    if target is not None:
151
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
152
                        cur_logprob = self._gather_logprob(cur_head_logprob, cur_target)
153
                else:
154
                    tail_logit = self._logit(hidden, cur_W, cur_b, self.out_projs[i])
155
                    tail_logprob = tf.nn.log_softmax(tail_logit)
156
                    cluster_prob_idx = self.cutoffs[0] + i - 1  # No probability for the head cluster
157
                    logprob_i = head_logprob[..., cluster_prob_idx, None] + tail_logprob
158
                    out.append(logprob_i)
159
                    if target is not None:
160
                        cur_head_logprob = tf.boolean_mask(head_logprob, mask)
161
                        cur_tail_logprob = tf.boolean_mask(tail_logprob, mask)
162
                        cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
163
                        cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
164
                if target is not None:
165
                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
166
            out = tf.concat(out, axis=-1)
167

168
        if target is not None:
169
            if return_mean:
170
                loss = tf.reduce_mean(loss)
171
            # Add the training-time loss value to the layer using `self.add_loss()`.
172
            self.add_loss(loss)
173

174
            # Log the loss as a metric (we could log arbitrary metrics,
175
            # including different metrics for training and inference.
176
            self.add_metric(loss, name=self.name, aggregation="mean" if return_mean else "")
177

178
        return out
179
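

if __name__ == "__main__":
    # Added usage sketch (not in the original file): a minimal smoke test,
    # assuming TF 2.x with the legacy tf.keras Layer API (`add_loss` /
    # `add_metric`) and that the relative import above resolves, e.g. when
    # this module is run from within its package.
    layer = TFAdaptiveSoftmaxMask(vocab_size=12, d_embed=8, d_proj=8, cutoffs=[4, 8])
    hidden = tf.random.normal((5, 2, 8))  # (seq_len, batch, d_proj)
    target = tf.random.uniform((5, 2), maxval=12, dtype=tf.int32)
    logprob = layer((hidden, target))
    print(logprob.shape)  # (5, 2, 12): log-probabilities over the full vocabulary
    print(layer.losses)   # mean negative log-likelihood recorded via add_loss()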
