# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Neural Speech Recognition Losses."""

from lingvo import compat as tf
from lingvo.core import py_utils
import semiring
import utils

SeqLen = utils.SeqLen

def interleave_with_blank(x, blank, axis):
  """Interleaves x with blanks along axis.

  E.g. if input is AAA, we want the output to be bAbAbAb.

  Args:
     x: Tensor of shape [..., U, ...].
     blank: Tensor of same shape as x except U is now 1.
     axis: Axis of input that U corresponds to.

  Returns:
    output: Tensor of same shape as x except U is now 2*U+1.
  """
  input_shape = tf.shape(x)
  input_rank = tf.rank(x)
  u = input_shape[axis]

  # Create a blanks tensor of same shape as x.
  blanks = tf.broadcast_to(blank, input_shape)  # [..., U, ...]

  # Interleave x with blanks.
  interleaved_dims = input_shape + tf.one_hot(
      axis, depth=input_rank, dtype=tf.int32) * u
  interleaved = tf.reshape(
      tf.stack([blanks, x], axis=axis + 1), interleaved_dims)  # [..., 2*U, ...]

  # Add an extra blank at the end.
  interleaved = tf.concat([interleaved, blank], axis=axis)  # [..., 2*U+1, ...]

  return interleaved
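

# A minimal illustrative sketch (not part of the original library): how
# `interleave_with_blank` expands a label sequence into the blank-interleaved
# CTC state sequence. The labels [[1, 2, 3]] (shape [B=1, U=3]) and the blank
# id 0 are hypothetical; when evaluated, the result has shape [1, 2*U+1] and
# equals [[0, 1, 0, 2, 0, 3, 0]], i.e. bAbAbAb with b = 0.
def _example_interleave_with_blank():
  labels = tf.constant([[1, 2, 3]], dtype=tf.int32)  # [B, U]
  blank = tf.zeros([1, 1], dtype=tf.int32)  # [B, 1]
  return interleave_with_blank(labels, blank, axis=1)  # [B, 2*U+1]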


def ctc(input_logits, output_labels, input_seq_len, output_seq_len):
  """CTC loss.

  B: Batch size.
  T: Input sequence dimension.
  U: Output sequence dimension.
  V: Vocabulary size.

  Alex Graves's CTC paper can be found at:
  https://www.cs.toronto.edu/~graves/icml_2006.pdf

  The following Distill article is also a helpful reference:
  https://distill.pub/2017/ctc/

  Args:
    input_logits: Logits for input sequence of shape [B, T, V]. We assume the
      0th token in the vocabulary represents the blank.
    output_labels: Labels for output sequence of shape [B, U].
    input_seq_len: Sequence lengths for input sequence of shape [B].
    output_seq_len: Sequence lengths for output sequence of shape [B].

  Returns:
    CTC loss, which is a tf.Tensor of shape B.
  """
  log_sum = ctc_semiring(
      sr=semiring.LogSemiring(),
      sr_inputs=(input_logits,),
      output_labels=output_labels,
      input_seq_len=input_seq_len,
      output_seq_len=output_seq_len)[0]
  return -log_sum
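

# A minimal illustrative sketch (not part of the original library): one
# hypothetical way to call `ctc` above. The shapes, lengths, and random logits
# below are made up for illustration; token id 0 is reserved for the blank, so
# the labels are drawn from {1, ..., V-1}.
def _example_ctc_loss():
  b, t, u, v = 2, 10, 4, 6
  input_logits = tf.random.normal([b, t, v])  # [B, T, V]
  output_labels = tf.random.uniform(
      [b, u], minval=1, maxval=v, dtype=tf.int32)  # [B, U]
  input_seq_len = tf.constant([10, 8], dtype=tf.int32)  # [B]
  output_seq_len = tf.constant([4, 3], dtype=tf.int32)  # [B]
  return ctc(input_logits, output_labels, input_seq_len, output_seq_len)  # [B]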


def rnnt(s1_logits, s2_logits, s1_seq_len, s2_seq_len):
  """RNN-T loss for two sequences.

  B: Batch size.
  D: Loop skewing dimension, i.e. the sum of s1 and s2.
  S1: Sequence 1 dimension (canonically, the input).
  S2: Sequence 2 dimension (canonically, the output).

  At each step, we consume/produce from each of the two sequences until all the
  tokens have been used. We abide by the convention that the last token is
  always from sequence 1. The RNN-T loss is a dynamic programming algorithm that
  sums the probability of every possible such sequence.

  The dynamic programming equation is:
    alpha[s1, s2] = alpha[s1-1, s2] * s1_logits[s1-1, s2] +
                    alpha[s1, s2-1] * s2_logits[s1, s2-1].

  The boundary condition is:
    loss = alpha[S1, S2] * s1_logits[S1, S2].

  In practice, we work in the log space for numerical stability. Instead of
  two nested for loops, we do loop skewing by iterating through the diagonal
  line, D.

  Loop skewing for RNN-T is discussed here:
    T. Bagby, K. Rao and K. C. Sim, "Efficient Implementation of Recurrent
    Neural Network Transducer in Tensorflow," 2018 IEEE Spoken Language
    Technology Workshop (SLT), Athens, Greece, 2018, pp. 506-512.

  The following wikipedia article is also a helpful reference:
  https://en.wikipedia.org/wiki/Polytope_model

  Args:
    s1_logits: Logits for sequence 1 of shape [B, S1, S2]. We always end with a
      token from sequence 1.
    s2_logits: Logits for sequence 2 of shape [B, S1, S2].
    s1_seq_len: Sequence lengths for sequence 1 of shape [B].
    s2_seq_len: Sequence lengths for sequence 2 of shape [B].

  Returns:
    RNN-T loss, which is a tf.Tensor of shape B.
  """
  log_sum = rnnt_semiring(
      sr=semiring.LogSemiring(),
      s1_inputs=(s1_logits,),
      s2_inputs=(s2_logits,),
      s1_seq_len=s1_seq_len,
      s2_seq_len=s2_seq_len)[0]
  return -log_sum
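

# A minimal illustrative sketch (not part of the original library): one
# hypothetical way to call `rnnt` above. Both logit grids share the
# [B, S1, S2] shape. The lengths are chosen so that s2_seq_len stays strictly
# below S2, since the final gather in `rnnt_semiring` indexes position
# s2_seq_len of that axis.
def _example_rnnt_loss():
  b, s1, s2 = 2, 6, 4
  s1_logits = tf.random.normal([b, s1, s2])  # [B, S1, S2]
  s2_logits = tf.random.normal([b, s1, s2])  # [B, S1, S2]
  s1_seq_len = tf.constant([6, 5], dtype=tf.int32)  # [B]
  s2_seq_len = tf.constant([3, 2], dtype=tf.int32)  # [B]
  return rnnt(s1_logits, s2_logits, s1_seq_len, s2_seq_len)  # [B]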


def ctc_semiring(sr, sr_inputs, output_labels, input_seq_len, output_seq_len):
  """CTC loss for an arbitrary semiring.

  The CTC dynamic programming graph stays the same, but the addition and
  multiplication operations are now given by the semiring.

  B: Batch size.
  T: Input sequence dimension.
  U: Output sequence dimension.
  V: Vocabulary size.

  Alex Graves's CTC paper can be found at:
  https://www.cs.toronto.edu/~graves/icml_2006.pdf

  The following Distill article is also a helpful reference:
  https://distill.pub/2017/ctc/

  Args:
    sr: Semiring object where each state is a tuple of tf.Tensors of shape [B,
      T, V].
    sr_inputs: Input to the CTC graph.
    output_labels: Labels for output sequence of shape [B, U].
    input_seq_len: Sequence lengths for input sequence of shape [B].
    output_seq_len: Sequence lengths for output sequence of shape [B].

  Returns:
    Output of the CTC graph, which is a tuple of tf.Tensors of shape [B].
  """
  tf.debugging.assert_shapes([(state, ['B', 'T', 'V']) for state in sr_inputs])
  tf.debugging.assert_shapes([
      (output_labels, ['B', 'U']),
      (input_seq_len, ['B']),
      (output_seq_len, ['B']),
  ])

  # Convert inputs to tensors.
  sr_inputs = tuple(tf.convert_to_tensor(i) for i in sr_inputs)  # [B, T, V]
  output_labels = tf.convert_to_tensor(output_labels)  # [B, U]
  input_seq_len = tf.convert_to_tensor(input_seq_len)  # [B]
  output_seq_len = tf.convert_to_tensor(output_seq_len)  # [B]

  b, t, v = py_utils.GetShape(sr_inputs[0])
  u = py_utils.GetShape(output_labels)[1]
  dtype = sr_inputs[0].dtype

  # Create a bitmask for the labels that returns True for the start of every
  # contiguous segment of repeating characters.
  # E.g. For AABBB, we have TFTFF. After padding with blanks (which are filled
  # with False), i.e. bAbAbBbBbBb, we have FTFFFTFFFFF.
  is_label_distinct = tf.not_equal(output_labels[:, :-1],
                                   output_labels[:, 1:])  # [B, U-1]
  is_label_distinct = tf.pad(
      is_label_distinct, [[0, 0], [1, 0]], constant_values=True)  # [B, U]
  is_label_distinct = interleave_with_blank(
      is_label_distinct, tf.zeros([b, 1], dtype=tf.bool), axis=1)  # [B, 2*U+1]

  # Create CTC tables of shape [B, 2*U+1, T].
  ctc_state_tables = []
  onehot_labels = tf.one_hot(output_labels, depth=v, dtype=dtype)  # [B, U, V]
  for state in sr_inputs:
    ctc_state_table = tf.einsum('buv, btv -> but', onehot_labels,
                                state)  # [B, U, T]
    blank = tf.transpose(state[:, :, :1], [0, 2, 1])  # [B, 1, T]
    ctc_state_table = interleave_with_blank(
        ctc_state_table, blank, axis=1)  # [B, 2*U+1, T]
    ctc_state_tables.append(ctc_state_table)
  ctc_state_tables = sr.convert_logits(tuple(ctc_state_tables))  # [B, 2*U+1, T]

  # Mask out invalid starting states, i.e. all but the first two.
  start_mask = tf.concat(
      [tf.ones([2], dtype=tf.bool),
       tf.zeros([2 * u - 1], dtype=tf.bool)],
      axis=0)[:, tf.newaxis]  # [2*U+1, 1]
  start_mask = tf.pad(
      start_mask, [[0, 0], [0, t - 1]], constant_values=True)  # [2*U+1, T]
  start_mask = tf.tile(start_mask[tf.newaxis], [b, 1, 1])  # [B, 2*U+1, T]

  additive_identity = sr.additive_identity(
      shape=(b, 2 * u + 1, t), dtype=dtype)  # [B, 2*U+1, T]
  ctc_state_tables = [
      tf.where(start_mask, cst, ai)
      for cst, ai in zip(ctc_state_tables, additive_identity)
  ]  # [B, 2*U+1, T]

  # Iterate through the CTC tables.
  ctc_state_tables = tuple(
      tf.transpose(cst, [2, 0, 1]) for cst in ctc_state_tables)  # [T, B, 2*U+1]

  def _generate_transitions(acc, ai):
    """Generates CTC state transitions."""
    plus_zero = acc  # [B, 2*U+1]
    plus_one = tf.pad(
        acc[:, :-1], [[0, 0], [1, 0]], constant_values=ai[0, 0])  # [B, 2*U+1]
    plus_two = tf.pad(
        acc[:, :-2], [[0, 0], [2, 0]], constant_values=ai[0, 0])  # [B, 2*U+1]
    plus_two = tf.where(is_label_distinct, plus_two, ai)  # [B, 2*U+1]
    return [plus_zero, plus_one, plus_two]

  def _step(acc, x):
    additive_identity = sr.additive_identity(
        shape=(b, 2 * u + 1), dtype=dtype)  # [B, 2*U+1]
    path_sum = tuple(
        _generate_transitions(acc_i, ai)
        for acc_i, ai in zip(acc, additive_identity))  # [B, 2*U+1]
    path_sum = sr.add_list(utils.tuple_to_list(path_sum))  # [B, 2*U+1]
    new_acc = sr.multiply(path_sum, x)  # [B, 2*U+1]
    return new_acc

  ctc_state_tables = tf.scan(_step, ctc_state_tables)  # [T, B, 2*U+1]

  # Sum up the final two states.
  indices_final_state = tf.stack([
      input_seq_len - 1,
      tf.range(b),
      2 * output_seq_len,
  ],
                                 axis=1)  # [B, 3]
  indices_penultimate_state = tf.stack([
      input_seq_len - 1,
      tf.range(b),
      2 * output_seq_len - 1,
  ],
                                       axis=1)  # [B, 3]

  final_state = tuple(
      tf.gather_nd(cst, indices_final_state) for cst in ctc_state_tables)  # [B]
  penultimate_state = tuple(
      tf.gather_nd(cst, indices_penultimate_state)
      for cst in ctc_state_tables)  # [B]
  result = sr.add(final_state, penultimate_state)  # [B]

  # Zero out invalid losses.
  return tuple(
      tf.where(tf.math.is_inf(r), tf.zeros_like(r), r) for r in result)  # [B]
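

# A minimal illustrative sketch (not part of the original library): the
# repeat-detection mask that `ctc_semiring` builds before its scan, spelled out
# for the AABBB example from the comment above using hypothetical integer
# labels [1, 1, 2, 2, 2]. When evaluated, the result is FTFFFTFFFFF over the
# blank-interleaved states bAbAbBbBbBb; the skip ("plus_two") transition is
# only allowed into states where this mask is True.
def _example_ctc_repeat_mask():
  output_labels = tf.constant([[1, 1, 2, 2, 2]])  # [B=1, U=5]
  is_label_distinct = tf.not_equal(output_labels[:, :-1],
                                   output_labels[:, 1:])  # [B, U-1]
  is_label_distinct = tf.pad(
      is_label_distinct, [[0, 0], [1, 0]], constant_values=True)  # [B, U]
  return interleave_with_blank(
      is_label_distinct, tf.zeros([1, 1], dtype=tf.bool), axis=1)  # [B, 2*U+1]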


def rnnt_semiring(sr, s1_inputs, s2_inputs, s1_seq_len, s2_seq_len):
  """RNN-T loss for an arbitrary semiring.

  The RNN-T dynamic programming graph stays the same, but the addition (+) and
  multiplication (*) operations are now given by the semiring.

  B: Batch size.
  D: Loop skewing dimension, i.e. the sum of s1 and s2.
  S1: Sequence 1 dimension (canonically, the input).
  S2: Sequence 2 dimension (canonically, the output).

  At each step, we consume/produce from each of the two sequences until all the
  tokens have been used. We abide by the convention that the last token is
  always from sequence 1. The RNN-T loss is a dynamic programming algorithm that
  sums the probability of every possible such sequence.

  The dynamic programming equation is:
    alpha[s1, s2] = alpha[s1-1, s2] (*) s1_inputs[s1-1, s2] (+)
                    alpha[s1, s2-1] (*) s2_inputs[s1, s2-1].

  The boundary condition is:
    loss = alpha[S1, S2] (*) s1_inputs[S1, S2].

  Instead of two nested for loops, we do loop skewing by iterating through the
  diagonal line, D.

  Loop skewing for RNN-T is discussed here:
    T. Bagby, K. Rao and K. C. Sim, "Efficient Implementation of Recurrent
    Neural Network Transducer in Tensorflow," 2018 IEEE Spoken Language
    Technology Workshop (SLT), Athens, Greece, 2018, pp. 506-512.

  The following wikipedia article is also a helpful reference:
  https://en.wikipedia.org/wiki/Polytope_model

  Args:
    sr: Semiring object where each state is a tuple of tf.Tensors of shape [B,
      S1, S2].
    s1_inputs: Sequence 1 inputs to the RNN-T graph.
    s2_inputs: Sequence 2 inputs to the RNN-T graph.
    s1_seq_len: Sequence lengths for sequence 1 of shape [B].
    s2_seq_len: Sequence lengths for sequence 2 of shape [B].

  Returns:
    Output of the RNN-T graph, which is a tuple of tf.Tensors of shape [B].
  """
  tf.debugging.assert_shapes([(state, ['B', 'S1', 'S2']) for state in s1_inputs
                             ])
  tf.debugging.assert_shapes([(state, ['B', 'S1', 'S2']) for state in s2_inputs
                             ])
  tf.debugging.assert_shapes([
      (s1_seq_len, ['B']),
      (s2_seq_len, ['B']),
  ])

  # Convert inputs to tensor.
  s1_inputs = tuple(tf.convert_to_tensor(i) for i in s1_inputs)  # [B, S1, S2]
  s2_inputs = tuple(tf.convert_to_tensor(i) for i in s2_inputs)  # [B, S1, S2]
  s1_seq_len = tf.convert_to_tensor(s1_seq_len)  # [B]
  s2_seq_len = tf.convert_to_tensor(s2_seq_len)  # [B]

  # Convert inputs to semiring inputs.
  s1_inputs = sr.convert_logits(s1_inputs)  # [B, S1, S2]
  s2_inputs = sr.convert_logits(s2_inputs)  # [B, S1, S2]

  b, s1, s2 = py_utils.GetShape(s1_inputs[0])
  d = s1 + s2 - 1
  dtype = s1_inputs[0].dtype

  # Mask invalid logit states.
  s1_mask = tf.sequence_mask(
      s1_seq_len, maxlen=s1)[:, :, tf.newaxis]  # [B, S1, 1]
  s1_mask = tf.broadcast_to(s1_mask, [b, s1, s2])  # [B, S1, S2]
  additive_identity = sr.additive_identity(
      shape=(b, s1, s2), dtype=dtype)  # [B, S1, S2]
  s2_inputs = tuple(
      tf.where(s1_mask, s2_state, ai)
      for s2_state, ai in zip(s2_inputs, additive_identity))  # [B, S1, S2]

  # Skew the RNN-T table.
  def _skew(y, ai_scalar):
    """Skew the loop along the dimension of sequence 1."""
    # [B, S1, S2] => [D, B, S2]
    y = tf.transpose(y, [0, 2, 1])  # [B, S2, S1]
    y = tf.pad(
        y, [[0, 0], [0, 0], [0, s2]],
        constant_values=ai_scalar[0])  # [B, S2, S1+S2]
    y = tf.reshape(y, [b, s2 * (s1 + s2)])  # [B, S2*(S1+S2)]
    y = y[:, :s2 * d]  # [B, S2*(S1+S2-1)]
    y = tf.reshape(y, [b, s2, d])  # [B, S2, S1+S2-1]
    y = tf.transpose(y, [2, 0, 1])  # [D, B, S2]
    return y

  additive_identity_scalar = sr.additive_identity(
      shape=(1,), dtype=dtype)  # [1]
  r_s1 = tuple(
      _skew(s1_i, ai)
      for s1_i, ai in zip(s1_inputs, additive_identity_scalar))  # [D, B, S2]
  r_s2 = tuple(
      _skew(s2_i, ai)
      for s2_i, ai in zip(s2_inputs, additive_identity_scalar))  # [D, B, S2]

  # Iterate through the RNN-T table.
  def _shift_down_s2(y, ai):
    """Shift the sequence 2 inputs down by one time step."""
    return tf.pad(
        y[:, :-1], [[0, 0], [1, 0]], constant_values=ai[0, 0])  # [B, S2]

  def _step(alpha_d, x):
    s1_d, s2_d = x  # [B, S2]
    additive_identity = sr.additive_identity(
        shape=(b, s2), dtype=dtype)  # [B, S2]
    a_s1_d = sr.multiply(alpha_d, s1_d)  # [B, S2]
    a_s2_d = sr.multiply(alpha_d, s2_d)  # [B, S2]
    a_s2_d = tuple(
        _shift_down_s2(as2d_i, ai)
        for as2d_i, ai in zip(a_s2_d, additive_identity))  # [B, S2]
    path_sum = sr.add(a_s1_d, a_s2_d)  # [B, S2]
    return path_sum

  # Initial condition for the loop accumulator.
  multiplicative_identity = sr.multiplicative_identity(
      shape=(b, 1), dtype=dtype)  # [B, 1]
  additive_identity = sr.additive_identity(
      shape=(b, s2 - 1), dtype=dtype)  # [B, S2-1]
  init_d = tuple(
      tf.concat([mi, ai], axis=1)
      for mi, ai in zip(multiplicative_identity, additive_identity))  # [B, S2]

  # Compute the RNN-T loss with loop skewing.
  r_alpha = tf.scan(
      _step,
      (r_s1, r_s2),
      init_d,
  )  # [D, B, S2]
  indices = tf.stack([s1_seq_len + s2_seq_len - 1,
                      tf.range(b), s2_seq_len],
                     axis=1)  # [B, 3]
  result = tuple(tf.gather_nd(r_a, indices) for r_a in r_alpha)  # [B]

  # Zero out invalid losses from length zero sequences. These losses are invalid
  # because empty sequences do not produce an alignment path.
  valid_seqs = tf.math.logical_and(s1_seq_len > 0, s2_seq_len > 0)  # [B]
  return tuple(tf.where(valid_seqs, r, tf.zeros_like(r)) for r in result)  # [B]
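

# A minimal illustrative sketch (not part of the original library): what the
# loop-skewing transform used by `rnnt_semiring` (its internal `_skew` helper)
# does to one [B, S1, S2] table. After skewing, entry [d, b, j] holds the
# original entry [b, d - j, j], so each slice along the leading axis is one
# anti-diagonal of the grid (all cells with s1_index + s2_index == d) and the
# tf.scan above can sweep a whole diagonal per step. This standalone copy is
# for illustration only and pads with 0.0 rather than the semiring's additive
# identity.
def _example_loop_skewing():
  b, s1, s2 = 1, 3, 2
  d = s1 + s2 - 1
  y = tf.reshape(tf.range(b * s1 * s2, dtype=tf.float32), [b, s1, s2])

  skewed = tf.transpose(y, [0, 2, 1])  # [B, S2, S1]
  skewed = tf.pad(skewed, [[0, 0], [0, 0], [0, s2]])  # [B, S2, S1+S2]
  skewed = tf.reshape(skewed, [b, s2 * (s1 + s2)])  # [B, S2*(S1+S2)]
  skewed = skewed[:, :s2 * d]  # [B, S2*D]
  skewed = tf.reshape(skewed, [b, s2, d])  # [B, S2, D]
  skewed = tf.transpose(skewed, [2, 0, 1])  # [D, B, S2]

  # E.g. the diagonal skewed[2, 0, :] equals [y[0, 2, 0], y[0, 1, 1]].
  return skewed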