google-research
433 строки · 15.4 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Neural Speech Recognition Losses."""
17
18from lingvo import compat as tf19from lingvo.core import py_utils20import semiring21import utils22
23SeqLen = utils.SeqLen24
25
def interleave_with_blank(x, blank,
                          axis):
  """Inserts a blank before every element of x along axis, plus one at the end.

  For a sequence AAA along `axis`, the result is bAbAbAb, i.e. the size of
  that axis grows from U to 2*U+1.

  Args:
    x: Tensor of shape [..., U, ...].
    blank: Tensor shaped like x, except the `axis` dimension is 1.
    axis: The axis of x that holds the U elements.

  Returns:
    output: Tensor shaped like x, except the `axis` dimension is 2*U+1.
  """
  shape = tf.shape(x)
  rank = tf.rank(x)
  u = shape[axis]

  # Materialize one blank per element of x so they can be stacked pairwise.
  tiled_blanks = tf.broadcast_to(blank, shape)  # [..., U, ...]

  # Stacking on axis+1 pairs each blank with its element; collapsing that new
  # dimension back into `axis` doubles it: bA bA ... bA.
  doubled_shape = shape + u * tf.one_hot(axis, depth=rank, dtype=tf.int32)
  pairs = tf.stack([tiled_blanks, x], axis=axis + 1)
  interleaved = tf.reshape(pairs, doubled_shape)  # [..., 2*U, ...]

  # Close with one trailing blank: bA bA ... bA b.
  return tf.concat([interleaved, blank], axis=axis)  # [..., 2*U+1, ...]
58
def ctc(input_logits, output_labels,
        input_seq_len, output_seq_len):
  """CTC loss, i.e. the negative log-probability of all valid alignments.

  Shape legend:
    B: Batch size.
    T: Input sequence dimension.
    U: Output sequence dimension.
    V: Vocabulary size.

  This is the log-semiring specialization of `ctc_semiring`.

  References:
    Alex Graves's CTC paper: https://www.cs.toronto.edu/~graves/icml_2006.pdf
    Distill article on CTC: https://distill.pub/2017/ctc/

  Args:
    input_logits: Logits for input sequence of shape [B, T, V]. We assume the
      0th token in the vocabulary represents the blank.
    output_labels: Labels for output sequence of shape [B, U].
    input_seq_len: Sequence lengths for input sequence of shape [B].
    output_seq_len: Sequence lengths for output sequence of shape [B].

  Returns:
    CTC loss, which is a tf.Tensor of shape B.
  """
  # ctc_semiring returns one tensor per semiring input; we pass exactly one.
  (log_sum,) = ctc_semiring(
      sr=semiring.LogSemiring(),
      sr_inputs=(input_logits,),
      output_labels=output_labels,
      input_seq_len=input_seq_len,
      output_seq_len=output_seq_len)
  # Negate the log path sum to obtain a loss.
  return -log_sum
92
def rnnt(s1_logits,
         s2_logits,
         s1_seq_len,
         s2_seq_len):
  """RNN-T loss for two sequences.

  Shape legend:
    B: Batch size.
    D: Loop skewing dimension, i.e. the sum of s1 and s2.
    S1: Sequence 1 dimension (canonically, the input).
    S2: Sequence 2 dimension (canonically, the output).

  At each step, we consume/produce from each of the two sequences until all
  the tokens have been used, with the convention that the last token always
  comes from sequence 1. The loss sums the probability of every such
  interleaving via dynamic programming:

    alpha[s1, s2] = alpha[s1-1, s2] * s1_logits[s1-1, s2] +
                    alpha[s1, s2-1] * s2_logits[s1, s2-1]

  with boundary condition loss = alpha[S1, S2] * s1_logits[S1, S2].

  This is the log-semiring specialization of `rnnt_semiring`, which works in
  log space for numerical stability and iterates over the skewed diagonal D
  rather than using two nested loops.

  Loop skewing for RNN-T is discussed here:
    T. Bagby, K. Rao and K. C. Sim, "Efficient Implementation of Recurrent
    Neural Network Transducer in Tensorflow," 2018 IEEE Spoken Language
    Technology Workshop (SLT), Athens, Greece, 2018, pp. 506-512.
  See also: https://en.wikipedia.org/wiki/Polytope_model

  Args:
    s1_logits: Logits for sequence 1 of shape [B, S1, S2]. We always end with
      a token from sequence 1.
    s2_logits: Logits for sequence 2 of shape [B, S1, S2].
    s1_seq_len: Sequence lengths for sequence 1 of shape [B].
    s2_seq_len: Sequence lengths for sequence 2 of shape [B].

  Returns:
    RNN-T loss, which is a tf.Tensor of shape B.
  """
  # rnnt_semiring returns one tensor per semiring input; we pass exactly one.
  (log_sum,) = rnnt_semiring(
      sr=semiring.LogSemiring(),
      s1_inputs=(s1_logits,),
      s2_inputs=(s2_logits,),
      s1_seq_len=s1_seq_len,
      s2_seq_len=s2_seq_len)
  # Negate the log path sum to obtain a loss.
  return -log_sum
146
def ctc_semiring(sr,
                 sr_inputs, output_labels,
                 input_seq_len,
                 output_seq_len):
  """CTC loss for an arbitrary semiring.

  The CTC dynamic programming graph stays the same, but the addition and
  multiplication operations are now given by the semiring.

  B: Batch size.
  T: Input sequence dimension.
  U: Output sequence dimension.
  V: Vocabulary size.

  Alex Graves's CTC paper can be found at:
  https://www.cs.toronto.edu/~graves/icml_2006.pdf

  The following Distill article is also a helpful reference:
  https://distill.pub/2017/ctc/

  Args:
    sr: Semiring object where each state is a tuple of tf.Tensors of shape [B,
      T, V].
    sr_inputs: Input to the CTC graph. One tensor of shape [B, T, V] per
      semiring state; the 0th vocabulary entry is treated as the blank.
    output_labels: Labels for output sequence of shape [B, U].
    input_seq_len: Sequence lengths for input sequence of shape [B].
    output_seq_len: Sequence lengths for output sequence of shape [B].

  Returns:
    Output of the CTC graph, which is a tuple of tf.Tensors of shape [B], one
    per semiring state. Entries that come out infinite (invalid paths) are
    replaced with zero.
  """
  tf.debugging.assert_shapes([(state, ['B', 'T', 'V']) for state in sr_inputs])
  tf.debugging.assert_shapes([
      (output_labels, ['B', 'U']),
      (input_seq_len, ['B']),
      (output_seq_len, ['B']),
  ])

  # Convert inputs to tensors.
  sr_inputs = tuple(tf.convert_to_tensor(i) for i in sr_inputs)  # [B, T, V]
  output_labels = tf.convert_to_tensor(output_labels)  # [B, U]
  input_seq_len = tf.convert_to_tensor(input_seq_len)  # [B]
  output_seq_len = tf.convert_to_tensor(output_seq_len)  # [B]

  b, t, v = py_utils.GetShape(sr_inputs[0])
  u = py_utils.GetShape(output_labels)[1]
  dtype = sr_inputs[0].dtype

  # Create a bitmask for the labels that returns True for the start of every
  # contiguous segment of repeating characters.
  # E.g. For AABBB, we have TFTFF. After padding with blanks (which are filled
  # with False), i.e. bAbAbBbBbBb, we have FTFFFTFFFFF.
  # This mask later disables the "skip blank" transition between repeated
  # labels, which CTC forbids.
  is_label_distinct = tf.not_equal(output_labels[:, :-1],
                                   output_labels[:, 1:])  # [B, U-1]
  is_label_distinct = tf.pad(
      is_label_distinct, [[0, 0], [1, 0]], constant_values=True)  # [B, U]
  is_label_distinct = interleave_with_blank(
      is_label_distinct, tf.zeros([b, 1], dtype=tf.bool), axis=1)  # [B, 2*U+1]

  # Create CTC tables of shape [B, 2*U+1, T]: per-label scores gathered from
  # the vocabulary axis via one-hot einsum, interleaved with the blank scores
  # (vocabulary entry 0).
  ctc_state_tables = []
  onehot_labels = tf.one_hot(output_labels, depth=v, dtype=dtype)  # [B, U, V]
  for state in sr_inputs:
    ctc_state_table = tf.einsum('buv, btv -> but', onehot_labels,
                                state)  # [B, U, T]
    blank = tf.transpose(state[:, :, :1], [0, 2, 1])  # [B, 1, T]
    ctc_state_table = interleave_with_blank(
        ctc_state_table, blank, axis=1)  # [B, 2*U+1, T]
    ctc_state_tables.append(ctc_state_table)
  ctc_state_tables = sr.convert_logits(tuple(ctc_state_tables))  # [B, 2*U+1, T]

  # Mask out invalid starting states, i.e. all but the first two. At t=0 a CTC
  # path may begin only at the initial blank or the first label.
  start_mask = tf.concat(
      [tf.ones([2], dtype=tf.bool),
       tf.zeros([2 * u - 1], dtype=tf.bool)],
      axis=0)[:, tf.newaxis]  # [2*U+1, 1]
  start_mask = tf.pad(
      start_mask, [[0, 0], [0, t - 1]], constant_values=True)  # [2*U+1, T]
  start_mask = tf.tile(start_mask[tf.newaxis], [b, 1, 1])  # [B, 2*U+1, T]

  additive_identity = sr.additive_identity(
      shape=(b, 2 * u + 1, t), dtype=dtype)  # [B, 2*U+1, T]
  ctc_state_tables = [
      tf.where(start_mask, cst, ai)
      for cst, ai in zip(ctc_state_tables, additive_identity)
  ]  # [B, 2*U+1, T]

  # Iterate through the CTC tables. Time becomes the leading axis so tf.scan
  # can step over it.
  ctc_state_tables = tuple(
      tf.transpose(cst, [2, 0, 1]) for cst in ctc_state_tables)  # [T, B, 2*U+1]

  def _generate_transitions(acc, ai):
    """Generates CTC state transitions.

    plus_zero: stay in the same state; plus_one: advance one state;
    plus_two: skip over a blank, allowed only where consecutive labels differ
    (enforced via is_label_distinct). Shifted-in positions are padded with the
    semiring's additive identity so they contribute nothing to the sum.
    """
    plus_zero = acc  # [B, 2*U+1]
    plus_one = tf.pad(
        acc[:, :-1], [[0, 0], [1, 0]], constant_values=ai[0, 0])  # [B, 2*U+1]
    plus_two = tf.pad(
        acc[:, :-2], [[0, 0], [2, 0]], constant_values=ai[0, 0])  # [B, 2*U+1]
    plus_two = tf.where(is_label_distinct, plus_two, ai)  # [B, 2*U+1]
    return [plus_zero, plus_one, plus_two]

  def _step(acc, x):
    """One time step: semiring-sum the three transitions, then multiply in x."""
    additive_identity = sr.additive_identity(
        shape=(b, 2 * u + 1), dtype=dtype)  # [B, 2*U+1]
    path_sum = tuple(
        _generate_transitions(acc_i, ai)
        for acc_i, ai in zip(acc, additive_identity))  # [B, 2*U+1]
    path_sum = sr.add_list(utils.tuple_to_list(path_sum))  # [B, 2*U+1]
    new_acc = sr.multiply(path_sum, x)  # [B, 2*U+1]
    return new_acc

  ctc_state_tables = tf.scan(_step, ctc_state_tables)  # [T, B, 2*U+1]

  # Sum up the final two states: a valid CTC path ends on the last label
  # (index 2*output_seq_len - 1) or on the trailing blank (2*output_seq_len),
  # read at the last valid time step input_seq_len - 1.
  indices_final_state = tf.stack([
      input_seq_len - 1,
      tf.range(b),
      2 * output_seq_len,
  ],
                                 axis=1)  # [B, 3]
  indices_penultimate_state = tf.stack([
      input_seq_len - 1,
      tf.range(b),
      2 * output_seq_len - 1,
  ],
                                       axis=1)  # [B, 3]

  final_state = tuple(
      tf.gather_nd(cst, indices_final_state) for cst in ctc_state_tables)  # [B]
  penultimate_state = tuple(
      tf.gather_nd(cst, indices_penultimate_state)
      for cst in ctc_state_tables)  # [B]
  result = sr.add(final_state, penultimate_state)  # [B]

  # Zero out invalid losses (e.g. no feasible alignment leaves the log-space
  # accumulator at infinity).
  return tuple(
      tf.where(tf.math.is_inf(r), tf.zeros_like(r), r) for r in result)  # [B]
285
def rnnt_semiring(
    sr,
    s1_inputs,
    s2_inputs,
    s1_seq_len,
    s2_seq_len):
  """RNN-T loss for an arbitrary semiring.

  The RNN-T dynamic programming graph stays the same, but the addition (+) and
  multiplication (*) operations are now given by the semiring.

  B: Batch size.
  D: Loop skewing dimension, i.e. the sum of s1 and s2.
  S1: Sequence 1 dimension (canonically, the input).
  S2: Sequence 2 dimension (canonically, the output).

  At each step, we consume/produce from each of the two sequences until all the
  tokens have been used. We abide by the convention that the last token is
  always from sequence 1. The RNN-T loss is a dynamic programming algorithm that
  sums the probability of every possible such sequence.

  The dynamic programming equation is:
    alpha[s1, s2] = alpha[s1-1, s2] (*) s1_inputs[s1-1, s2] (+)
                    alpha[s1, s2-1] (*) s2_inputs[s1, s2-1].

  The boundary condition is:
    loss = alpha[S1, S2] (*) s1_inputs[S1, S2].

  Instead of two nested for loops, we do loop skewing by iterating through the
  diagonal line, D.

  Loop skewing for RNN-T is discussed here:
  T. Bagby, K. Rao and K. C. Sim, "Efficient Implementation of Recurrent
  Neural Network Transducer in Tensorflow," 2018 IEEE Spoken Language
  Technology Workshop (SLT), Athens, Greece, 2018, pp. 506-512.

  The following wikipedia article is also a helpful reference:
  https://en.wikipedia.org/wiki/Polytope_model

  Args:
    sr: Semiring object where each state is a tuple of tf.Tensors of shape [B,
      S1, S2].
    s1_inputs: Sequence 1 inputs to the RNN-T graph. One [B, S1, S2] tensor
      per semiring state.
    s2_inputs: Sequence 2 inputs to the RNN-T graph. One [B, S1, S2] tensor
      per semiring state.
    s1_seq_len: Sequence lengths for sequence 1 of shape [B].
    s2_seq_len: Sequence lengths for sequence 2 of shape [B].

  Returns:
    Output of the RNN-T graph, which is a tuple of tf.Tensors of shape [B],
    one per semiring state. Batch entries with a zero-length sequence are set
    to zero.
  """
  tf.debugging.assert_shapes([(state, ['B', 'S1', 'S2']) for state in s1_inputs
                             ])
  tf.debugging.assert_shapes([(state, ['B', 'S1', 'S2']) for state in s2_inputs
                             ])
  tf.debugging.assert_shapes([
      (s1_seq_len, ['B']),
      (s2_seq_len, ['B']),
  ])

  # Convert inputs to tensor.
  s1_inputs = tuple(tf.convert_to_tensor(i) for i in s1_inputs)  # [B, S1, S2]
  s2_inputs = tuple(tf.convert_to_tensor(i) for i in s2_inputs)  # [B, S1, S2]
  s1_seq_len = tf.convert_to_tensor(s1_seq_len)  # [B]
  s2_seq_len = tf.convert_to_tensor(s2_seq_len)  # [B]

  # Convert inputs to semiring inputs.
  s1_inputs = sr.convert_logits(s1_inputs)  # [B, S1, S2]
  s2_inputs = sr.convert_logits(s2_inputs)  # [B, S1, S2]

  b, s1, s2 = py_utils.GetShape(s1_inputs[0])
  d = s1 + s2 - 1  # Number of skewed diagonals.
  dtype = s1_inputs[0].dtype

  # Mask invalid logit states: beyond s1_seq_len, sequence-2 emissions are
  # replaced with the additive identity so they cannot contribute to any path.
  s1_mask = tf.sequence_mask(
      s1_seq_len, maxlen=s1)[:, :, tf.newaxis]  # [B, S1, 1]
  s1_mask = tf.broadcast_to(s1_mask, [b, s1, s2])  # [B, S1, S2]
  additive_identity = sr.additive_identity(
      shape=(b, s1, s2), dtype=dtype)  # [B, S1, S2]
  s2_inputs = tuple(
      tf.where(s1_mask, s2_state, ai)
      for s2_state, ai in zip(s2_inputs, additive_identity))  # [B, S1, S2]

  # Skew the RNN-T table.
  def _skew(y, ai_scalar):
    """Skew the loop along the dimension of sequence 1.

    Pads each [S2, S1] slice with identity values and reshapes so that the
    anti-diagonals of the original table become the rows of the result,
    turning the 2-D recurrence into a 1-D scan over D.
    """
    # [B, S1, S2] => [D, B, S2]
    y = tf.transpose(y, [0, 2, 1])  # [B, S2, S1]
    y = tf.pad(
        y, [[0, 0], [0, 0], [0, s2]],
        constant_values=ai_scalar[0])  # [B, S2, S1+S2]
    y = tf.reshape(y, [b, s2 * (s1 + s2)])  # [B, S2*(S1+S2)]
    y = y[:, :s2 * d]  # [B, S2*(S1+S2-1)]
    y = tf.reshape(y, [b, s2, d])  # [B, S2, S1+S2-1]
    y = tf.transpose(y, [2, 0, 1])  # [D, B, S2]
    return y

  additive_identity_scalar = sr.additive_identity(
      shape=(1,), dtype=dtype)  # [D, B, S2]
  r_s1 = tuple(
      _skew(s1_i, ai)
      for s1_i, ai in zip(s1_inputs, additive_identity_scalar))  # [D, B, S2]
  r_s2 = tuple(
      _skew(s2_i, ai)
      for s2_i, ai in zip(s2_inputs, additive_identity_scalar))  # [D, B, S2]

  # Iterate through the RNN-T table.
  def _shift_down_s2(y, ai):
    """Shift the sequence 2 inputs down by one time step."""
    return tf.pad(
        y[:, :-1], [[0, 0], [1, 0]], constant_values=ai[0, 0])  # [B, S2]

  def _step(alpha_d, x):
    """One diagonal step: combine the s1 (stay) and shifted s2 (advance) arcs."""
    s1_d, s2_d = x  # [B, S2]
    additive_identity = sr.additive_identity(
        shape=(b, s2), dtype=dtype)  # [B, S2]
    a_s1_d = sr.multiply(alpha_d, s1_d)  # [B, S2]
    a_s2_d = sr.multiply(alpha_d, s2_d)  # [B, S2]
    a_s2_d = tuple(
        _shift_down_s2(as2d_i, ai)
        for as2d_i, ai in zip(a_s2_d, additive_identity))  # [B, S2]
    path_sum = sr.add(a_s1_d, a_s2_d)  # [B, S2]
    return path_sum

  # Initial condition for the loop accumulator: the multiplicative identity at
  # the origin cell, the additive identity everywhere else.
  multiplicative_identity = sr.multiplicative_identity(
      shape=(b, 1), dtype=dtype)  # [B, 1]
  additive_identity = sr.additive_identity(
      shape=(b, s2 - 1), dtype=dtype)  # [B, S2-1]
  init_d = tuple(
      tf.concat([mi, ai], axis=1)
      for mi, ai in zip(multiplicative_identity, additive_identity))  # [B, S2]

  # Compute the RNN-T loss with loop skewing.
  # NOTE(review): tf.scan emits one slice per input element, so r_alpha should
  # be [D, B, S2]; the gather below at diagonal s1_seq_len + s2_seq_len - 1 and
  # column s2_seq_len then presumably relies on seq lens being strictly inside
  # the padded S1/S2 dims — confirm against callers.
  r_alpha = tf.scan(
      _step,
      (r_s1, r_s2),
      init_d,
  )  # [D+1, B, S2]
  indices = tf.stack([s1_seq_len + s2_seq_len - 1,
                      tf.range(b), s2_seq_len],
                     axis=1)  # [B, 3]
  result = tuple(tf.gather_nd(r_a, indices) for r_a in r_alpha)  # [B]

  # Zero out invalid losses from length zero sequences. These losses are invalid
  # because empty sequences do not produce an alignment path.
  valid_seqs = tf.math.logical_and(s1_seq_len > 0, s2_seq_len > 0)  # [B]
  return tuple(tf.where(valid_seqs, r, tf.zeros_like(r)) for r in result)  # [B]