google-research

Форк
0
306 строк · 9.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Semiring."""
17
import abc
18
from typing import Generic, TypeVar
19

20
from lingvo import compat as tf
21
import tensorflow_probability as tfp
22

23
import utils
24

25
# Representation of a single semiring element; every semiring in this module
# uses a tuple of tensors.
T = TypeVar('T')
TensorTuple = utils.TensorTuple
# Element of LogSemiring: (log(p),).
LogTensor = tuple[tf.Tensor]
# Element of LogEntropySemiring: (log(p), log(-p log(q))).
DualTensor = tuple[tf.Tensor, tf.Tensor]
# Element of LogReverseKLSemiring:
# (log(p), log(q), log(-q log(q)), log(-q log(p))).
LogReverseKLTensor = tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
30

31

32
class Semiring(abc.ABC, Generic[T]):
  """Abstract base class for a semiring.

  A monoid is a set equipped with a binary associative operation and an
  identity element.

  A semiring is a set equipped with addition and multiplication such that:
  1) Addition (+) is a commutative monoid with identity (0).
  2) Multiplication (*) is a monoid with identity (1).
  3) Multiplication distributes over addition from both sides.
  4) The additive identity (0) is an annihilator for multiplication (*), i.e.
  multiplying any element with (0) results in (0).

  Concrete subclasses of Semiring need to implement seven methods:
  1) additive_identity
  2) add
  3) add_list
  4) multiplicative_identity
  5) multiply
  6) multiply_list
  7) convert_logits

  add and multiply are the binary operations from the definition of a
  semiring. add_list and multiply_list are required as well because a whole
  list can often be reduced more efficiently than by applying the binary
  operation repeatedly. convert_logits converts the given logits into the
  input values expected by the semiring.

  The type parameter T is the representation of one semiring element.
  """

  @abc.abstractmethod
  def additive_identity(self, shape=(1,), dtype=tf.float32) -> T:
    """Returns the additive identity (0).

    Args:
      shape: Shape of the tensor(s) making up the returned element.
      dtype: Data type of the tensor(s) making up the returned element.

    Returns:
      The additive identity element.
    """

  @abc.abstractmethod
  def add(self, elem_1: T, elem_2: T) -> T:
    """Returns the semiring sum elem_1 (+) elem_2."""

  @abc.abstractmethod
  def add_list(self, elems_list: list[T]) -> T:
    """Returns the semiring sum of all elements in elems_list."""

  @abc.abstractmethod
  def multiplicative_identity(self, shape=(1,), dtype=tf.float32) -> T:
    """Returns the multiplicative identity (1).

    Args:
      shape: Shape of the tensor(s) making up the returned element.
      dtype: Data type of the tensor(s) making up the returned element.

    Returns:
      The multiplicative identity element.
    """

  @abc.abstractmethod
  def multiply(self, elem_1: T, elem_2: T) -> T:
    """Returns the semiring product elem_1 (*) elem_2."""

  @abc.abstractmethod
  def multiply_list(self, elems_list: list[T]) -> T:
    """Returns the semiring product of all elements in elems_list."""

  @abc.abstractmethod
  def convert_logits(self, elem) -> T:
    """Converts logits into a semiring element.

    Args:
      elem: Logits; the expected structure is defined by each subclass.

    Returns:
      The corresponding semiring element.
    """
92

93

94
class LogSemiring(Semiring[LogTensor]):
  """Log semiring.

  Each element is a 1-tuple (log(p),) where p is a real number from [0, 1].

  Additive identity:
  (0) = neg inf.

  Addition:
  a (+) b = LogSumExp(a, b).

  Multiplicative identity:
  (1) = 0.

  Multiplication:
  a (*) b = a + b.

  Convert logits:
  log(p) -> log(p).

  Multiplication results are passed through utils.safe_result, which the
  original authors describe as the numerically stable path.
  """
  _LOGP = 0  # Index of log(p) in a LogTensor.

  def additive_identity(self, shape=(1,), dtype=tf.float32) -> LogTensor:
    """Returns (log(0),), i.e. (neg inf,), of the given shape and dtype."""
    del self
    return (utils.logzero(shape=shape, dtype=dtype),)

  def add(self, elem_1: LogTensor, elem_2: LogTensor) -> LogTensor:
    """Returns the componentwise LogSumExp of two elements."""
    del self
    return tuple(map(utils.logsumexp_list, zip(elem_1, elem_2)))

  def add_list(self, elems_list: list[LogTensor]) -> LogTensor:
    """Returns the componentwise LogSumExp over all elements.

    Args:
      elems_list: Non-empty list of LogTensor elements.

    Returns:
      The semiring sum as a LogTensor.

    Raises:
      ValueError: If elems_list is empty.
    """
    del self
    if not elems_list:
      # zip(*[]) is empty, so the reduction would silently return () instead
      # of a LogTensor.
      raise ValueError('elems_list must contain at least one element.')
    return tuple(map(utils.logsumexp_list, zip(*elems_list)))

  def multiplicative_identity(self, shape=(1,),
                              dtype=tf.float32) -> LogTensor:
    """Returns (log(1),), i.e. (0,), of the given shape and dtype."""
    del self
    return (tf.zeros(shape=shape, dtype=dtype),)

  def multiply(self, elem_1: LogTensor, elem_2: LogTensor) -> LogTensor:
    """Returns (log(p1) + log(p2),), i.e. the log of the product."""
    return (utils.safe_result(elem_1[self._LOGP] + elem_2[self._LOGP]),)

  def multiply_list(self, elems_list: list[LogTensor]) -> LogTensor:
    """Returns the sum of all log(p) components as a LogTensor.

    Args:
      elems_list: Non-empty list of LogTensor elements. tf.add_n requires all
        components to share the same shape and dtype.

    Returns:
      The semiring product as a LogTensor.

    Raises:
      ValueError: If elems_list is empty.
    """
    if not elems_list:
      # tf.add_n rejects an empty list as well; fail early with a clear message.
      raise ValueError('elems_list must contain at least one element.')
    logps = [elem[self._LOGP] for elem in elems_list]
    return (utils.safe_result(tf.add_n(logps)),)

  def convert_logits(self, elem: LogTensor) -> LogTensor:
    """Returns elem unchanged: logits already are log(p)."""
    del self
    return elem
148

149

150
class LogEntropySemiring(Semiring[DualTensor]):
  """Log Entropy semiring.

  Each element is of the form <log(p), log(-plog(q))> where p and q are real
  numbers from [0, 1]. Addition and multiplication follow the dual number
  system, but with a log morphism applied on both arguments.
  https://en.wikipedia.org/wiki/Dual_number

  Additive identity:
  (0) = <neg inf, neg inf>.

  Addition:
  <a, b> (+) <c, d> = <LogSumExp(a, c), LogSumExp(b, d)>.

  Multiplicative identity:
  (1) = <0, neg inf>.

  Multiplication:
  <a, b> (*) <c, d> = <a + c, LogSumExp(a + d, b + c)>.

  Convert logits:
  <log(p), log(q)> -> <log(p), log(-plog(q))>.

  Multiplication is delegated to utils.safe_result / utils.logcrossmultiply,
  which the original authors describe as numerically stable.
  """
  _LOGP = 0  # Index of log(p) in a DualTensor.
  _LOGMINUSPLOGQ = 1  # Index of log(-p log(q)) in a DualTensor.

  def additive_identity(self, shape=(1,), dtype=tf.float32) -> DualTensor:
    """Returns <neg inf, neg inf> of the given shape and dtype."""
    del self
    neg_inf = utils.logzero(shape=shape, dtype=dtype)
    return (neg_inf, neg_inf)

  def add(self, elem_1: DualTensor, elem_2: DualTensor) -> DualTensor:
    """Returns the componentwise LogSumExp of two elements."""
    del self
    return tuple(map(utils.logsumexp_list, zip(elem_1, elem_2)))

  def add_list(self, elems_list: list[DualTensor]) -> DualTensor:
    """Returns the componentwise LogSumExp over all elements.

    Args:
      elems_list: Non-empty list of DualTensor elements.

    Returns:
      The semiring sum as a DualTensor.

    Raises:
      ValueError: If elems_list is empty.
    """
    del self
    if not elems_list:
      # zip(*[]) is empty, so the reduction would silently return () instead
      # of a DualTensor.
      raise ValueError('elems_list must contain at least one element.')
    return tuple(map(utils.logsumexp_list, zip(*elems_list)))

  def multiplicative_identity(self, shape=(1,),
                              dtype=tf.float32) -> DualTensor:
    """Returns <0, neg inf> of the given shape and dtype."""
    del self
    zero = tf.zeros(shape=shape, dtype=dtype)
    neg_inf = utils.logzero(shape=shape, dtype=dtype)
    return (zero, neg_inf)

  def multiply(self, elem_1: DualTensor, elem_2: DualTensor) -> DualTensor:
    """Returns <a + c, LogSumExp(a + d, b + c)> for <a, b> (*) <c, d>."""
    logp = utils.safe_result(elem_1[self._LOGP] + elem_2[self._LOGP])
    # Log-space form of the dual-number cross term p1 * r2 + r1 * p2.
    logminusplogq = utils.logcrossmultiply(elem_1[self._LOGP],
                                           elem_1[self._LOGMINUSPLOGQ],
                                           elem_2[self._LOGP],
                                           elem_2[self._LOGMINUSPLOGQ])
    return (logp, logminusplogq)

  def multiply_list(self, elems_list: list[DualTensor]) -> DualTensor:
    """Returns the semiring product of all elements.

    Each component is stacked along a new leading axis (tf.stack requires the
    elements to share a shape). tfp.math.scan_associative then computes every
    prefix product with a parallel prefix scan, which is valid because
    semiring multiplication is associative. The last prefix is the product of
    the whole list.

    Args:
      elems_list: Non-empty list of DualTensor elements.

    Returns:
      The semiring product as a DualTensor.

    Raises:
      ValueError: If elems_list is empty.
    """
    if not elems_list:
      raise ValueError('elems_list must contain at least one element.')
    stacked = tuple(map(tf.stack, zip(*elems_list)))
    prefix_products = tfp.math.scan_associative(self.multiply, stacked)
    return tuple(component[-1] for component in prefix_products)

  def convert_logits(self,
                     elem: tuple[tf.Tensor, tf.Tensor]) -> DualTensor:
    """Maps logits <log(p), log(q)> to <log(p), log(-p log(q))>."""
    del self
    logp, logq = elem
    logminusplogq = utils.logminus(logp, logq)
    return (logp, logminusplogq)
221

222

223
class LogReverseKLSemiring(Semiring[LogReverseKLTensor]):
  """Log Reverse-KL semiring.

  Each element is of the form <log(p), log(q), log(-qlog(q)), log(-qlog(p))>
  where p and q are real numbers from [0, 1].

  Additive identity:
  (0) = <neg inf, neg inf, neg inf, neg inf>.

  Addition:
  <a, b, c, d> (+) <e, f, g, h> = <LogSumExp(a, e), LogSumExp(b, f),
                                   LogSumExp(c, g), LogSumExp(d, h)>.

  Multiplicative identity:
  (1) = <0, 0, neg inf, neg inf>.

  Multiplication:
  <a, b, c, d> (*) <e, f, g, h> = <a + e, b + f, LogSumExp(b + g, c + f),
                                   LogSumExp(b + h, d + f)>.

  Convert logits:
  <log(p), log(q)> -> <log(p), log(q), log(-qlog(q)), log(-qlog(p))>.

  Multiplication is delegated to utils.safe_result / utils.logcrossmultiply,
  which the original authors describe as numerically stable.
  """
  _LOGP = 0  # Index of log(p) in a LogReverseKLTensor.
  _LOGQ = 1  # Index of log(q) in a LogReverseKLTensor.
  _LOGMINUSQLOGQ = 2  # Index of log(-q log(q)) in a LogReverseKLTensor.
  _LOGMINUSQLOGP = 3  # Index of log(-q log(p)) in a LogReverseKLTensor.

  def additive_identity(self, shape=(1,),
                        dtype=tf.float32) -> LogReverseKLTensor:
    """Returns <neg inf, neg inf, neg inf, neg inf> of the given shape."""
    del self
    neg_inf = utils.logzero(shape=shape, dtype=dtype)
    return (neg_inf, neg_inf, neg_inf, neg_inf)

  def add(self, elem_1: LogReverseKLTensor,
          elem_2: LogReverseKLTensor) -> LogReverseKLTensor:
    """Returns the componentwise LogSumExp of two elements."""
    del self
    return tuple(map(utils.logsumexp_list, zip(elem_1, elem_2)))

  def add_list(
      self, elems_list: list[LogReverseKLTensor]) -> LogReverseKLTensor:
    """Returns the componentwise LogSumExp over all elements.

    Args:
      elems_list: Non-empty list of LogReverseKLTensor elements.

    Returns:
      The semiring sum as a LogReverseKLTensor.

    Raises:
      ValueError: If elems_list is empty.
    """
    del self
    if not elems_list:
      # zip(*[]) is empty, so the reduction would silently return () instead
      # of a LogReverseKLTensor.
      raise ValueError('elems_list must contain at least one element.')
    return tuple(map(utils.logsumexp_list, zip(*elems_list)))

  def multiplicative_identity(self, shape=(1,),
                              dtype=tf.float32) -> LogReverseKLTensor:
    """Returns <0, 0, neg inf, neg inf> of the given shape and dtype."""
    del self
    zero = tf.zeros(shape=shape, dtype=dtype)
    neg_inf = utils.logzero(shape=shape, dtype=dtype)
    return (zero, zero, neg_inf, neg_inf)

  def multiply(self, elem_1: LogReverseKLTensor,
               elem_2: LogReverseKLTensor) -> LogReverseKLTensor:
    """Returns the product as defined in the class docstring."""
    logp = utils.safe_result(elem_1[self._LOGP] + elem_2[self._LOGP])
    logq = utils.safe_result(elem_1[self._LOGQ] + elem_2[self._LOGQ])
    # Both accumulators combine with log(q) as the dual-number "real" part:
    # LogSumExp(b + g, c + f) and LogSumExp(b + h, d + f).
    logminusqlogq = utils.logcrossmultiply(elem_1[self._LOGQ],
                                           elem_1[self._LOGMINUSQLOGQ],
                                           elem_2[self._LOGQ],
                                           elem_2[self._LOGMINUSQLOGQ])
    logminusqlogp = utils.logcrossmultiply(elem_1[self._LOGQ],
                                           elem_1[self._LOGMINUSQLOGP],
                                           elem_2[self._LOGQ],
                                           elem_2[self._LOGMINUSQLOGP])
    return (logp, logq, logminusqlogq, logminusqlogp)

  def multiply_list(
      self, elems_list: list[LogReverseKLTensor]) -> LogReverseKLTensor:
    """Returns the semiring product of all elements.

    Each component is stacked along a new leading axis (tf.stack requires the
    elements to share a shape). tfp.math.scan_associative then computes every
    prefix product with a parallel prefix scan, which is valid because
    semiring multiplication is associative. The last prefix is the product of
    the whole list.

    Args:
      elems_list: Non-empty list of LogReverseKLTensor elements.

    Returns:
      The semiring product as a LogReverseKLTensor.

    Raises:
      ValueError: If elems_list is empty.
    """
    if not elems_list:
      raise ValueError('elems_list must contain at least one element.')
    stacked = tuple(map(tf.stack, zip(*elems_list)))
    prefix_products = tfp.math.scan_associative(self.multiply, stacked)
    return tuple(component[-1] for component in prefix_products)

  def convert_logits(
      self, elem: tuple[tf.Tensor, tf.Tensor]) -> LogReverseKLTensor:
    """Maps <log(p), log(q)> to <log p, log q, log(-q log q), log(-q log p)>."""
    del self
    logp, logq = elem
    logminusqlogq = utils.logminus(logq, logq)
    logminusqlogp = utils.logminus(logq, logp)
    return (logp, logq, logminusqlogq, logminusqlogp)
307

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.