google-research
306 строк · 9.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Semiring."""
17import abc18from typing import Generic, TypeVar19
20from lingvo import compat as tf21import tensorflow_probability as tfp22
23import utils24
# Type variable for the element type carried by a concrete Semiring subclass.
T = TypeVar('T')

# Element layouts consumed and produced by the concrete semirings in this file.
LogTensor = tuple[tf.Tensor]  # (log p,)
DualTensor = tuple[tf.Tensor, tf.Tensor]  # (log p, log(-p log q))
# (log p, log q, log(-q log q), log(-q log p))
LogReverseKLTensor = tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]

# Alias of the tuple-of-tensors type defined in utils.
TensorTuple = utils.TensorTuple
31
class Semiring(abc.ABC, Generic[T]):
  """Abstract base class for a semiring.

  A monoid is a set equipped with an associative binary operation and an
  identity element.

  A semiring is a set equipped with addition and multiplication such that:
  1) Addition (+) is a commutative monoid with identity (0).
  2) Multiplication (*) is a monoid with identity (1).
  3) Multiplication distributes over addition from both sides.
  4) The additive identity (0) is an annihilator for multiplication (*), i.e.
     multiplying any element by (0) results in (0).

  Concrete subclasses must implement seven methods:
  1) additive_identity
  2) add
  3) add_list
  4) multiplicative_identity
  5) multiply
  6) multiply_list
  7) convert_logits

  `add` and `multiply` are the binary operations from the definition of a
  semiring. `add_list` and `multiply_list` are required as well, because
  combining many elements at once can often be done more efficiently than
  applying the binary operation iteratively. `convert_logits` maps the given
  logits into the input values expected by the semiring.
  """

  @abc.abstractmethod
  def additive_identity(self, shape=(1,), dtype=tf.float32):
    """Returns the additive identity (0) with the given shape and dtype."""

  @abc.abstractmethod
  def add(self, elem_1, elem_2):
    """Returns the semiring sum elem_1 (+) elem_2."""

  @abc.abstractmethod
  def add_list(self, elems_list):
    """Returns the semiring sum of every element in elems_list."""

  @abc.abstractmethod
  def multiplicative_identity(self, shape=(1,), dtype=tf.float32):
    """Returns the multiplicative identity (1) with the given shape and dtype."""

  @abc.abstractmethod
  def multiply(self, elem_1, elem_2):
    """Returns the semiring product elem_1 (*) elem_2."""

  @abc.abstractmethod
  def multiply_list(self, elems_list):
    """Returns the semiring product of every element in elems_list."""

  @abc.abstractmethod
  def convert_logits(self, elem):
    """Maps logits into semiring input values."""
93
class LogSemiring(Semiring[LogTensor]):
  """Log semiring.

  Each element is a 1-tuple holding log(p), where p is a real number in [0, 1].

  Additive identity:
    (0) = neg inf.

  Addition:
    a (+) b = LogSumExp(a, b).

  Multiplicative identity:
    (1) = 0.

  Multiplication:
    a (*) b = a + b.

  Convert logits:
    log(p) -> log(p).

  Note that multiplication is implemented in a numerically stable manner.
  """
  _LOGP = 0  # Position of log(p) inside a LogTensor.

  def additive_identity(self, shape=(1,), dtype=tf.float32):
    """Returns (log 0,) built by utils.logzero."""
    del self
    log_zero = utils.logzero(shape=shape, dtype=dtype)
    return (log_zero,)

  def add(self, elem_1, elem_2):
    """Applies utils.logsumexp_list to each aligned pair of components."""
    del self
    return tuple(
        utils.logsumexp_list(pair) for pair in zip(elem_1, elem_2))

  def add_list(self, elems_list):
    """Applies utils.logsumexp_list to each component across all elements."""
    del self
    return tuple(
        utils.logsumexp_list(column) for column in zip(*elems_list))

  def multiplicative_identity(self, shape=(1,), dtype=tf.float32):
    """Returns (log 1,), i.e. a tensor of zeros."""
    del self
    log_one = tf.zeros(shape=shape, dtype=dtype)
    return (log_one,)

  def multiply(self, elem_1, elem_2):
    """Adds the log-probabilities and passes the sum through safe_result."""
    position = self._LOGP
    log_sum = elem_1[position] + elem_2[position]
    return (utils.safe_result(log_sum),)

  def multiply_list(self, elems_list):
    """Sums every log-probability with tf.add_n, then applies safe_result."""
    log_probs = [elem[self._LOGP] for elem in elems_list]
    total = tf.add_n(log_probs)
    return (utils.safe_result(total),)

  def convert_logits(self, elem):
    """Identity mapping: the logits are already log-probabilities."""
    del self
    return elem
149
class LogEntropySemiring(Semiring[DualTensor]):
  """Log Entropy semiring.

  Each element is of the form <log(p), log(-p log(q))>, where p and q are real
  numbers in [0, 1]. Addition and multiplication follow the dual number system,
  with a log morphism applied to both components.
  https://en.wikipedia.org/wiki/Dual_number

  Additive identity:
    (0) = <neg inf, neg inf>.

  Addition:
    <a, b> (+) <c, d> = <LogSumExp(a, c), LogSumExp(b, d)>.

  Multiplicative identity:
    (1) = <0, neg inf>.

  Multiplication:
    <a, b> (*) <c, d> = <a + c, LogSumExp(a + d, b + c)>.

  Convert logits:
    <log(p), log(q)> -> <log(p), log(-p log(q))>.

  Note that multiplication is implemented in a numerically stable manner.
  """
  _LOGP = 0  # Position of log(p) inside a DualTensor.
  _LOGMINUSPLOGQ = 1  # Position of log(-p log(q)) inside a DualTensor.

  def additive_identity(self, shape=(1,), dtype=tf.float32):
    """Returns <log 0, log 0>; both slots share the same tensor object."""
    del self
    log_zero = utils.logzero(shape=shape, dtype=dtype)
    return (log_zero, log_zero)

  def add(self, elem_1, elem_2):
    """Applies utils.logsumexp_list to each aligned pair of components."""
    del self
    return tuple(
        utils.logsumexp_list(pair) for pair in zip(elem_1, elem_2))

  def add_list(self, elems_list):
    """Applies utils.logsumexp_list to each component across all elements."""
    del self
    return tuple(
        utils.logsumexp_list(column) for column in zip(*elems_list))

  def multiplicative_identity(self, shape=(1,), dtype=tf.float32):
    """Returns <log 1, log 0> = <0, neg inf>."""
    del self
    log_one = tf.zeros(shape=shape, dtype=dtype)
    log_zero = utils.logzero(shape=shape, dtype=dtype)
    return (log_one, log_zero)

  def multiply(self, elem_1, elem_2):
    """Dual-number product in log space (see the class docstring)."""
    logp_1 = elem_1[self._LOGP]
    logp_2 = elem_2[self._LOGP]
    logp = utils.safe_result(logp_1 + logp_2)
    # Second slot is LogSumExp(a + d, b + c), per the multiplication rule above.
    logminusplogq = utils.logcrossmultiply(logp_1,
                                           elem_1[self._LOGMINUSPLOGQ],
                                           logp_2,
                                           elem_2[self._LOGMINUSPLOGQ])
    return (logp, logminusplogq)

  def multiply_list(self, elems_list):
    """Multiplies all elements via an associative scan over `multiply`."""
    # Stack each component across the list, run tfp's associative scan with
    # `multiply`, and keep the last prefix: the product of every element.
    stacked = tuple(tf.stack(column) for column in zip(*elems_list))
    prefixes = tfp.math.scan_associative(self.multiply, stacked)
    return tuple(prefix[-1] for prefix in prefixes)

  def convert_logits(self, elem):
    """Maps <log(p), log(q)> to <log(p), log(-p log(q))> via utils.logminus."""
    del self
    logp, logq = elem
    return (logp, utils.logminus(logp, logq))
222
class LogReverseKLSemiring(Semiring[LogReverseKLTensor]):
  """Log Reverse-KL semiring.

  Each element is of the form <log(p), log(q), log(-q log(q)), log(-q log(p))>,
  where p and q are real numbers in [0, 1].

  Additive identity:
    (0) = <neg inf, neg inf, neg inf, neg inf>.

  Addition:
    <a, b, c, d> (+) <e, f, g, h> = <LogSumExp(a, e), LogSumExp(b, f),
                                     LogSumExp(c, g), LogSumExp(d, h)>.

  Multiplicative identity:
    (1) = <0, 0, neg inf, neg inf>.

  Multiplication:
    <a, b, c, d> (*) <e, f, g, h> = <a + e, b + f, LogSumExp(b + g, c + f),
                                     LogSumExp(b + h, d + f)>.

  Convert logits:
    <log(p), log(q)> -> <log(p), log(q), log(-q log(q)), log(-q log(p))>.

  Note that multiplication is implemented in a numerically stable manner.
  """
  _LOGP = 0  # Position of log(p) inside a LogReverseKLTensor.
  _LOGQ = 1  # Position of log(q) inside a LogReverseKLTensor.
  _LOGMINUSQLOGQ = 2  # Position of log(-q log(q)) inside a LogReverseKLTensor.
  _LOGMINUSQLOGP = 3  # Position of log(-q log(p)) inside a LogReverseKLTensor.

  def additive_identity(self, shape=(1,), dtype=tf.float32):
    """Returns four copies of log 0 (all slots share one tensor object)."""
    del self
    log_zero = utils.logzero(shape=shape, dtype=dtype)
    return (log_zero,) * 4

  def add(self, elem_1, elem_2):
    """Applies utils.logsumexp_list to each aligned pair of components."""
    del self
    return tuple(
        utils.logsumexp_list(pair) for pair in zip(elem_1, elem_2))

  def add_list(self, elems_list):
    """Applies utils.logsumexp_list to each component across all elements."""
    del self
    return tuple(
        utils.logsumexp_list(column) for column in zip(*elems_list))

  def multiplicative_identity(self, shape=(1,), dtype=tf.float32):
    """Returns <log 1, log 1, log 0, log 0> = <0, 0, neg inf, neg inf>."""
    del self
    log_one = tf.zeros(shape=shape, dtype=dtype)
    log_zero = utils.logzero(shape=shape, dtype=dtype)
    return (log_one, log_one, log_zero, log_zero)

  def multiply(self, elem_1, elem_2):
    """Component-wise product in log space (see the class docstring)."""
    logq_1 = elem_1[self._LOGQ]
    logq_2 = elem_2[self._LOGQ]
    logp = utils.safe_result(elem_1[self._LOGP] + elem_2[self._LOGP])
    logq = utils.safe_result(logq_1 + logq_2)
    # Both entropy-like slots use the cross rule LogSumExp(b + g, c + f) and
    # LogSumExp(b + h, d + f) from the class docstring, weighted by log(q).
    logminusqlogq = utils.logcrossmultiply(logq_1,
                                           elem_1[self._LOGMINUSQLOGQ],
                                           logq_2,
                                           elem_2[self._LOGMINUSQLOGQ])
    logminusqlogp = utils.logcrossmultiply(logq_1,
                                           elem_1[self._LOGMINUSQLOGP],
                                           logq_2,
                                           elem_2[self._LOGMINUSQLOGP])
    return (logp, logq, logminusqlogq, logminusqlogp)

  def multiply_list(self, elems_list):
    """Multiplies all elements via an associative scan over `multiply`."""
    # Stack each component across the list, run tfp's associative scan with
    # `multiply`, and keep the last prefix: the product of every element.
    stacked = tuple(tf.stack(column) for column in zip(*elems_list))
    prefixes = tfp.math.scan_associative(self.multiply, stacked)
    return tuple(prefix[-1] for prefix in prefixes)

  def convert_logits(self, elem):
    """Maps <log(p), log(q)> to the four-slot reverse-KL representation."""
    del self
    logp, logq = elem
    logminusqlogq = utils.logminus(logq, logq)
    logminusqlogp = utils.logminus(logq, logp)
    return (logp, logq, logminusqlogq, logminusqlogp)