google-research

Форк
0
371 строка · 12.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for asr_loss."""
17

18
from absl.testing import parameterized
19
import asr_loss
20
from lingvo import compat as tf
21
import numpy as np
22
import semiring
23
import utils
24

25

26
class UtilsTest(tf.test.TestCase):
  """Tests for helper utilities exposed by asr_loss."""

  def testInterleaveWithBlank(self):
    """Checks interleave_with_blank against hand-written expectations.

    The blank is placed before, between and after every slice of `x` along
    the chosen axis: 3 columns grow to 7 (axis=1), 2 rows grow to 5 (axis=0).
    """
    x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    # Column-wise interleave: one blank column of zeros per row.
    column_blank = tf.constant([[0.0], [0.0]])
    want_columns = tf.constant([[0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0],
                                [0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0]])

    # Row-wise interleave: a single blank row of 0.5s.
    row_blank = tf.constant([[0.5, 0.5, 0.5]])
    want_rows = tf.constant([[0.5, 0.5, 0.5], [1.0, 2.0, 3.0],
                             [0.5, 0.5, 0.5], [4.0, 5.0, 6.0],
                             [0.5, 0.5, 0.5]])

    cases = (
        (column_blank, 1, want_columns),
        (row_blank, 0, want_rows),
    )
    for blank, axis, want in cases:
      got = asr_loss.interleave_with_blank(x, blank, axis=axis)
      self.assertAllClose(got, want)
45

46

47
class ASRLossTest(tf.test.TestCase):
  """Checks CTC and RNN-T losses against lattices enumerated by hand."""

  def _NegLogSumExp(self, paths):
    """Returns -logsumexp over the total score of each enumerated path.

    Args:
      paths: Sequence of paths; each path is a sequence of per-step log
        scores that are summed to give that path's total log score.

    Returns:
      The negated log-sum-exp of the path totals, computed with
      keepdims=True.
    """
    path_totals = [np.sum(path) for path in paths]
    return -tf.reduce_logsumexp(path_totals, keepdims=True)

  def testCTCByHand(self):
    """Compares asr_loss.ctc with a by-hand enumeration of a tiny lattice."""
    logits = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [-7.0, -8.0, -9.0],
        [-10.0, -11.0, -12.0],
    ]])
    labels = np.array([[1, 2, 2]])

    def run_ctc(input_seq_len, output_seq_len):
      return asr_loss.ctc(
          input_logits=logits,
          output_labels=labels,
          input_seq_len=input_seq_len,
          output_seq_len=output_seq_len,
      )

    # 4 frames, labels (1, 2, 2): the single enumerated path is (1, 2, b, 2),
    # where column 0 of the logits plays the role of the blank "b".
    full_loss = run_ctc([4], [3])
    self.assertAllClose(
        full_loss,
        self._NegLogSumExp([
            [-2.0, -6.0, -7.0, -12.0],  # (1, 2, b, 2)
        ]))

    # With only 3 frames the 4-step alignment above does not fit, so the
    # loss for this invalid example is expected to be zeroed out.
    invalid_loss = run_ctc([3], [3])
    self.assertAllClose(invalid_loss, np.array([0.0]))

    # Only the first 3 frames and the first 2 labels (1, 2) should count;
    # the trailing frame and label must be masked out.
    truncated_loss = run_ctc([3], [2])
    self.assertAllClose(
        truncated_loss,
        self._NegLogSumExp([
            [-1.0, -5.0, -9.0],  # (b, 1, 2)
            [-2.0, -4.0, -9.0],  # (1, b, 2)
            [-2.0, -5.0, -9.0],  # (1, 1, 2)
            [-2.0, -6.0, -7.0],  # (1, 2, b)
            [-2.0, -6.0, -9.0],  # (1, 2, 2)
        ]))

  def testRNNTByHand(self):
    """Compares asr_loss.rnnt with a by-hand enumeration of a tiny lattice."""
    s1_logits = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [0.0, 0.0, -13.0],
    ]])
    s2_logits = np.array([[
        [-7.0, -8.0, 0.0],
        [-9.0, -10.0, 0.0],
        [-11.0, -12.0, 0.0],
    ]])
    expected = self._NegLogSumExp([
        [-1.0, -4.0, -11.0, -12.0, -13.0],  # (S1, S1, S2, S2, S1)
        [-1.0, -9.0, -5.0, -12.0, -13.0],  # (S1, S2, S1, S2, S1)
        [-1.0, -9.0, -10.0, -6.0, -13.0],  # (S1, S2, S2, S1, S1)
        [-7.0, -2.0, -5.0, -12.0, -13.0],  # (S2, S1, S1, S2, S1)
        [-7.0, -2.0, -10.0, -6.0, -13.0],  # (S2, S1, S2, S1, S1)
        [-7.0, -8.0, -3.0, -6.0, -13.0],  # (S2, S2, S1, S1, S1)
    ])

    def run_rnnt(s1, s2, s1_seq_len, s2_seq_len):
      return asr_loss.rnnt(
          s1_logits=s1,
          s2_logits=s2,
          s1_seq_len=s1_seq_len,
          s2_seq_len=s2_seq_len,
      )

    self.assertAllClose(run_rnnt(s1_logits, s2_logits, [3], [2]), expected)

    # A zero-length s1 or s2 sequence is invalid; its loss should be zeroed.
    zeros = np.array([0.0])
    self.assertAllClose(run_rnnt(s1_logits, s2_logits, [0], [2]), zeros)
    self.assertAllClose(run_rnnt(s1_logits, s2_logits, [3], [0]), zeros)

    # The 0.0 entries are never used by any enumerated path. Bump them to
    # 1.0: if they are correctly masked out, the loss must not change.
    s1_bumped = np.where(s1_logits == 0.0, 1.0, s1_logits)
    s2_bumped = np.where(s2_logits == 0.0, 1.0, s2_logits)
    self.assertAllClose(run_rnnt(s1_bumped, s2_bumped, [3], [2]), expected)
177

178

179
class SemiringLossTest(parameterized.TestCase, tf.test.TestCase):
  """Checks semiring-valued CTC/RNN-T losses against by-hand path sums."""

  def setUp(self):
    """Builds two logit sets (p and q) and their enumerated path scores."""
    super().setUp()

    # Set up CTC inputs.
    self.ctc_logits_p = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [-7.0, -8.0, -9.0],
        [-10.0, -11.0, -12.0],
    ]])
    self.ctc_logits_q = np.array([[
        [-13.0, -14.0, -15.0],
        [-16.0, -17.0, -18.0],
        [-19.0, -20.0, -21.0],
        [-22.0, -23.0, -24.0],
    ]])
    self.ctc_short_paths_p = np.array([
        np.sum([-2.0, -6.0, -7.0, -12.0])  # (1, 2, b, 2)
    ])
    self.ctc_short_paths_q = np.array([
        np.sum([-14.0, -18.0, -19.0, -24.0]),  # (1, 2, b, 2)
    ])
    self.ctc_long_paths_p = np.array([
        np.sum([-1.0, -5.0, -9.0]),  # (b, 1, 2)
        np.sum([-2.0, -4.0, -9.0]),  # (1, b, 2)
        np.sum([-2.0, -5.0, -9.0]),  # (1, 1, 2)
        np.sum([-2.0, -6.0, -7.0]),  # (1, 2, b)
        np.sum([-2.0, -6.0, -9.0]),  # (1, 2, 2)
    ])
    self.ctc_long_paths_q = np.array([
        np.sum([-13.0, -17.0, -21.0]),  # (b, 1, 2)
        np.sum([-14.0, -16.0, -21.0]),  # (1, b, 2)
        np.sum([-14.0, -17.0, -21.0]),  # (1, 1, 2)
        np.sum([-14.0, -18.0, -19.0]),  # (1, 2, b)
        np.sum([-14.0, -18.0, -21.0]),  # (1, 2, 2)
    ])
    self.output_labels = np.array([[1, 2, 2]])
    self.input_seq_len = [4]
    self.output_seq_len = [3]
    self.invalid_input_seq_len = [3]
    self.invalid_output_seq_len = [3]
    self.unused_input_seq_len = [3]
    self.unused_output_seq_len = [2]

    # Set up RNN-T inputs.
    self.rnnt_s1_logits_p = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [0.0, 0.0, -13.0],
    ]])
    self.rnnt_s2_logits_p = np.array([[
        [-7.0, -8.0, 0.0],
        [-9.0, -10.0, 0.0],
        [-11.0, -12.0, 0.0],
    ]])
    self.rnnt_s1_logits_q = np.array([[
        [-14.0, -15.0, -16.0],
        [-17.0, -18.0, -19.0],
        [0.0, 0.0, -26.0],
    ]])
    self.rnnt_s2_logits_q = np.array([[
        [-20.0, -21.0, 0.0],
        [-22.0, -23.0, 0.0],
        [-24.0, -25.0, 0.0],
    ]])
    self.rnnt_paths_p = np.array([
        np.sum([-1.0, -4.0, -11.0, -12.0, -13.0]),  # (S1, S1, S2, S2, S1)
        np.sum([-1.0, -9.0, -5.0, -12.0, -13.0]),  # (S1, S2, S1, S2, S1)
        np.sum([-1.0, -9.0, -10.0, -6.0, -13.0]),  # (S1, S2, S2, S1, S1)
        np.sum([-7.0, -2.0, -5.0, -12.0, -13.0]),  # (S2, S1, S1, S2, S1)
        np.sum([-7.0, -2.0, -10.0, -6.0, -13.0]),  # (S2, S1, S2, S1, S1)
        np.sum([-7.0, -8.0, -3.0, -6.0, -13.0]),  # (S2, S2, S1, S1, S1)
    ])
    self.rnnt_paths_q = np.array([
        np.sum([-14.0, -17.0, -24.0, -25.0, -26.0]),  # (S1, S1, S2, S2, S1)
        np.sum([-14.0, -22.0, -18.0, -25.0, -26.0]),  # (S1, S2, S1, S2, S1)
        np.sum([-14.0, -22.0, -23.0, -19.0, -26.0]),  # (S1, S2, S2, S1, S1)
        np.sum([-20.0, -15.0, -18.0, -25.0, -26.0]),  # (S2, S1, S1, S2, S1)
        np.sum([-20.0, -15.0, -23.0, -19.0, -26.0]),  # (S2, S1, S2, S1, S1)
        np.sum([-20.0, -21.0, -16.0, -19.0, -26.0]),  # (S2, S2, S1, S1, S1)
    ])
    self.s1_seq_len = [3]
    self.s2_seq_len = [2]

  def ComputeLossByHand(self, sr_name, paths_p, paths_q):
    """Helper function to compute the expected semiring loss from the paths.

    Args:
      sr_name: 'logentropy' or 'logreversekl'; selects which tuple of
        quantities is built.
      paths_p: 1-D array with the total log score of each path under p.
      paths_q: 1-D array with the total log score of each path under q,
        aligned element-wise with `paths_p`.

    Returns:
      For 'logentropy': (logp, logminusplogq).
      For 'logreversekl': (logp, logq, logminusqlogq, logminusqlogp).
      Each entry is a logsumexp over the paths, computed with keepdims=True.

    Raises:
      ValueError: If `sr_name` is not one of the names above. Previously an
        unknown name silently returned None, which only surfaced later as a
        confusing assertAllClose failure.
    """

    def reduce_paths(x):
      return tf.reduce_logsumexp(x, keepdims=True)

    # Only compute the reductions the requested semiring actually needs.
    if sr_name == 'logentropy':
      logp = reduce_paths(paths_p)
      logminusplogq = reduce_paths(utils.logminus(paths_p, paths_q))
      return (logp, logminusplogq)
    if sr_name == 'logreversekl':
      logp = reduce_paths(paths_p)
      logq = reduce_paths(paths_q)
      logminusqlogq = reduce_paths(utils.logminus(paths_q, paths_q))
      logminusqlogp = reduce_paths(utils.logminus(paths_q, paths_p))
      return (logp, logq, logminusqlogq, logminusqlogp)
    raise ValueError(f'Unknown semiring name: {sr_name!r}')

  @parameterized.parameters([
      ('logentropy', semiring.LogEntropySemiring()),
      ('logreversekl', semiring.LogReverseKLSemiring()),
  ])
  def testCTCSemiring(self, sr_name, sr):
    """Checks ctc_semiring on valid, invalid and truncated inputs."""
    loss = asr_loss.ctc_semiring(
        sr=sr,
        sr_inputs=(self.ctc_logits_p, self.ctc_logits_q),
        output_labels=self.output_labels,
        input_seq_len=self.input_seq_len,
        output_seq_len=self.output_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.ctc_short_paths_p,
                                     self.ctc_short_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)

    # Check that invalid losses are zero-ed out.
    loss = asr_loss.ctc_semiring(
        sr=sr,
        sr_inputs=(self.ctc_logits_p, self.ctc_logits_q),
        output_labels=self.output_labels,
        input_seq_len=self.invalid_input_seq_len,
        output_seq_len=self.invalid_output_seq_len)
    for loss_component in loss:
      self.assertAllClose(loss_component, np.array([0.0]), atol=1e-37)

    # Check that the unused logits are masked out.
    loss = asr_loss.ctc_semiring(
        sr=sr,
        sr_inputs=(self.ctc_logits_p, self.ctc_logits_q),
        output_labels=self.output_labels,
        input_seq_len=self.unused_input_seq_len,
        output_seq_len=self.unused_output_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.ctc_long_paths_p,
                                     self.ctc_long_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)

  @parameterized.parameters([
      ('logentropy', semiring.LogEntropySemiring()),
      ('logreversekl', semiring.LogReverseKLSemiring()),
  ])
  def testRNNTSemiring(self, sr_name, sr):
    """Checks rnnt_semiring on valid, empty and perturbed-unused inputs."""
    loss = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=(self.rnnt_s1_logits_p, self.rnnt_s1_logits_q),
        s2_inputs=(self.rnnt_s2_logits_p, self.rnnt_s2_logits_q),
        s1_seq_len=self.s1_seq_len,
        s2_seq_len=self.s2_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.rnnt_paths_p,
                                     self.rnnt_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)

    # Check that invalid losses are zero-ed out.
    loss_1 = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=(self.rnnt_s1_logits_p, self.rnnt_s1_logits_q),
        s2_inputs=(self.rnnt_s2_logits_p, self.rnnt_s2_logits_q),
        s1_seq_len=[0],
        s2_seq_len=self.s2_seq_len)
    loss_2 = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=(self.rnnt_s1_logits_p, self.rnnt_s1_logits_q),
        s2_inputs=(self.rnnt_s2_logits_p, self.rnnt_s2_logits_q),
        s1_seq_len=self.s1_seq_len,
        s2_seq_len=[0])
    # One zero per component of the expected semiring tuple.
    zeros = tf.zeros_like(by_hand)

    self.assertAllClose(loss_1, zeros)
    self.assertAllClose(loss_2, zeros)

    # Check that the unused logits are masked out: entries that no
    # enumerated path uses (the 0.0s) are replaced by 1.23, and the loss
    # must stay the same.
    rnnt_s1_logits_p = np.where(self.rnnt_s1_logits_p == 0.0, 1.23,
                                self.rnnt_s1_logits_p)
    rnnt_s1_logits_q = np.where(self.rnnt_s1_logits_q == 0.0, 1.23,
                                self.rnnt_s1_logits_q)
    rnnt_s2_logits_p = np.where(self.rnnt_s2_logits_p == 0.0, 1.23,
                                self.rnnt_s2_logits_p)
    rnnt_s2_logits_q = np.where(self.rnnt_s2_logits_q == 0.0, 1.23,
                                self.rnnt_s2_logits_q)

    loss = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=(rnnt_s1_logits_p, rnnt_s1_logits_q),
        s2_inputs=(rnnt_s2_logits_p, rnnt_s2_logits_q),
        s1_seq_len=self.s1_seq_len,
        s2_seq_len=self.s2_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.rnnt_paths_p,
                                     self.rnnt_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)
368

369

370
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
372

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.