# Source repository: google-research
# File size: 371 lines, 12.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Tests for asr_loss."""
17
18from absl.testing import parameterized19import asr_loss20from lingvo import compat as tf21import numpy as np22import semiring23import utils24
25
class UtilsTest(tf.test.TestCase):
  """Tests for the tensor helpers exposed by asr_loss."""

  def testInterleaveWithBlank(self):
    """Checks interleave_with_blank against outputs written out by hand."""
    values = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

    # Along axis 1 the blank column is expected before, between and after
    # the columns of `values`.
    column_blank = tf.constant([[0.0], [0.0]])
    want_columns = tf.constant([
        [0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0],
        [0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0],
    ])
    got_columns = asr_loss.interleave_with_blank(values, column_blank, axis=1)

    # Along axis 0 the blank row is expected before, between and after the
    # rows of `values`.
    row_blank = tf.constant([[0.5, 0.5, 0.5]])
    want_rows = tf.constant([
        [0.5, 0.5, 0.5],
        [1.0, 2.0, 3.0],
        [0.5, 0.5, 0.5],
        [4.0, 5.0, 6.0],
        [0.5, 0.5, 0.5],
    ])
    got_rows = asr_loss.interleave_with_blank(values, row_blank, axis=0)

    self.assertAllClose(got_columns, want_columns)
    self.assertAllClose(got_rows, want_rows)
46
class ASRLossTest(tf.test.TestCase):
  """Compares asr_loss.ctc and asr_loss.rnnt with hand-enumerated lattices."""

  def testCTCByHand(self):
    """Enumerates a tiny CTC lattice by hand and compares with asr_loss.ctc."""
    logits = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [-7.0, -8.0, -9.0],
        [-10.0, -11.0, -12.0],
    ]])
    labels = np.array([[1, 2, 2]])

    def ctc_loss(num_frames, num_labels):
      # Only the sequence lengths vary between the three scenarios below.
      return asr_loss.ctc(
          input_logits=logits,
          output_labels=labels,
          input_seq_len=[num_frames],
          output_seq_len=[num_labels],
      )

    # All four frames, all three labels: a single alignment is enumerated.
    full_paths = [
        np.sum([-2.0, -6.0, -7.0, -12.0]),  # (1, 2, b, 2)
    ]
    want_full = -tf.reduce_logsumexp(full_paths, keepdims=True)
    self.assertAllClose(ctc_loss(4, 3), want_full)

    # Check that invalid losses are zero-ed out.
    want_invalid = np.array([0.0])
    self.assertAllClose(ctc_loss(3, 3), want_invalid)

    # Check that the unused logits are masked out.
    truncated_paths = [
        np.sum([-1.0, -5.0, -9.0]),  # (b, 1, 2)
        np.sum([-2.0, -4.0, -9.0]),  # (1, b, 2)
        np.sum([-2.0, -5.0, -9.0]),  # (1, 1, 2)
        np.sum([-2.0, -6.0, -7.0]),  # (1, 2, b)
        np.sum([-2.0, -6.0, -9.0]),  # (1, 2, 2)
    ]
    want_truncated = -tf.reduce_logsumexp(truncated_paths, keepdims=True)
    self.assertAllClose(ctc_loss(3, 2), want_truncated)

  def testRNNTByHand(self):
    """Enumerates a tiny RNN-T lattice by hand and compares with asr_loss.rnnt."""

    def rnnt_loss(s1, s2, s1_len, s2_len):
      # Thin wrapper so each scenario reads as a single line.
      return asr_loss.rnnt(
          s1_logits=s1,
          s2_logits=s2,
          s1_seq_len=s1_len,
          s2_seq_len=s2_len,
      )

    s1_base = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [0.0, 0.0, -13.0],
    ]])
    s2_base = np.array([[
        [-7.0, -8.0, 0.0],
        [-9.0, -10.0, 0.0],
        [-11.0, -12.0, 0.0],
    ]])

    path_scores = [
        np.sum([-1.0, -4.0, -11.0, -12.0, -13.0]),  # (S1, S1, S2, S2, S1)
        np.sum([-1.0, -9.0, -5.0, -12.0, -13.0]),  # (S1, S2, S1, S2, S1)
        np.sum([-1.0, -9.0, -10.0, -6.0, -13.0]),  # (S1, S2, S2, S1, S1)
        np.sum([-7.0, -2.0, -5.0, -12.0, -13.0]),  # (S2, S1, S1, S2, S1)
        np.sum([-7.0, -2.0, -10.0, -6.0, -13.0]),  # (S2, S1, S2, S1, S1)
        np.sum([-7.0, -8.0, -3.0, -6.0, -13.0]),  # (S2, S2, S1, S1, S1)
    ]
    want = -tf.reduce_logsumexp(path_scores, keepdims=True)

    self.assertAllClose(rnnt_loss(s1_base, s2_base, [3], [2]), want)

    # Check that invalid losses are zero-ed out.
    empty_s1_loss = rnnt_loss(s1_base, s2_base, [0], [2])
    empty_s2_loss = rnnt_loss(s1_base, s2_base, [3], [0])
    want_zero = np.array([0.0])

    self.assertAllClose(empty_s1_loss, want_zero)
    self.assertAllClose(empty_s2_loss, want_zero)

    # Check that the unused logits are masked out: the entries that were 0.0
    # above are now 1.0, and the expected loss must not change.
    s1_perturbed = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [1.0, 1.0, -13.0],
    ]])
    s2_perturbed = np.array([[
        [-7.0, -8.0, 1.0],
        [-9.0, -10.0, 1.0],
        [-11.0, -12.0, 1.0],
    ]])

    self.assertAllClose(rnnt_loss(s1_perturbed, s2_perturbed, [3], [2]), want)
178
class SemiringLossTest(parameterized.TestCase, tf.test.TestCase):
  """Compares the semiring CTC / RNN-T losses with hand-enumerated values.

  Every scenario uses two sets of logits, `p` and `q`. The expected outputs
  are assembled from the per-path log-weights listed in `setUp`.
  """

  def setUp(self):
    """Builds the shared CTC and RNN-T inputs and their enumerated paths."""
    super().setUp()

    # Set up CTC inputs.
    self.ctc_logits_p = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [-7.0, -8.0, -9.0],
        [-10.0, -11.0, -12.0],
    ]])
    self.ctc_logits_q = np.array([[
        [-13.0, -14.0, -15.0],
        [-16.0, -17.0, -18.0],
        [-19.0, -20.0, -21.0],
        [-22.0, -23.0, -24.0],
    ]])
    # Paths for input length 4 and output length 3.
    self.ctc_short_paths_p = np.array([
        np.sum([-2.0, -6.0, -7.0, -12.0]),  # (1, 2, b, 2)
    ])
    self.ctc_short_paths_q = np.array([
        np.sum([-14.0, -18.0, -19.0, -24.0]),  # (1, 2, b, 2)
    ])
    # Paths for input length 3 and output length 2.
    self.ctc_long_paths_p = np.array([
        np.sum([-1.0, -5.0, -9.0]),  # (b, 1, 2)
        np.sum([-2.0, -4.0, -9.0]),  # (1, b, 2)
        np.sum([-2.0, -5.0, -9.0]),  # (1, 1, 2)
        np.sum([-2.0, -6.0, -7.0]),  # (1, 2, b)
        np.sum([-2.0, -6.0, -9.0]),  # (1, 2, 2)
    ])
    self.ctc_long_paths_q = np.array([
        np.sum([-13.0, -17.0, -21.0]),  # (b, 1, 2)
        np.sum([-14.0, -16.0, -21.0]),  # (1, b, 2)
        np.sum([-14.0, -17.0, -21.0]),  # (1, 1, 2)
        np.sum([-14.0, -18.0, -19.0]),  # (1, 2, b)
        np.sum([-14.0, -18.0, -21.0]),  # (1, 2, 2)
    ])
    self.output_labels = np.array([[1, 2, 2]])
    self.input_seq_len = [4]
    self.output_seq_len = [3]
    self.invalid_input_seq_len = [3]
    self.invalid_output_seq_len = [3]
    self.unused_input_seq_len = [3]
    self.unused_output_seq_len = [2]

    # Set up RNN-T inputs.
    self.rnnt_s1_logits_p = np.array([[
        [-1.0, -2.0, -3.0],
        [-4.0, -5.0, -6.0],
        [0.0, 0.0, -13.0],
    ]])
    self.rnnt_s2_logits_p = np.array([[
        [-7.0, -8.0, 0.0],
        [-9.0, -10.0, 0.0],
        [-11.0, -12.0, 0.0],
    ]])
    self.rnnt_s1_logits_q = np.array([[
        [-14.0, -15.0, -16.0],
        [-17.0, -18.0, -19.0],
        [0.0, 0.0, -26.0],
    ]])
    self.rnnt_s2_logits_q = np.array([[
        [-20.0, -21.0, 0.0],
        [-22.0, -23.0, 0.0],
        [-24.0, -25.0, 0.0],
    ]])
    self.rnnt_paths_p = np.array([
        np.sum([-1.0, -4.0, -11.0, -12.0, -13.0]),  # (S1, S1, S2, S2, S1)
        np.sum([-1.0, -9.0, -5.0, -12.0, -13.0]),  # (S1, S2, S1, S2, S1)
        np.sum([-1.0, -9.0, -10.0, -6.0, -13.0]),  # (S1, S2, S2, S1, S1)
        np.sum([-7.0, -2.0, -5.0, -12.0, -13.0]),  # (S2, S1, S1, S2, S1)
        np.sum([-7.0, -2.0, -10.0, -6.0, -13.0]),  # (S2, S1, S2, S1, S1)
        np.sum([-7.0, -8.0, -3.0, -6.0, -13.0]),  # (S2, S2, S1, S1, S1)
    ])
    self.rnnt_paths_q = np.array([
        np.sum([-14.0, -17.0, -24.0, -25.0, -26.0]),  # (S1, S1, S2, S2, S1)
        np.sum([-14.0, -22.0, -18.0, -25.0, -26.0]),  # (S1, S2, S1, S2, S1)
        np.sum([-14.0, -22.0, -23.0, -19.0, -26.0]),  # (S1, S2, S2, S1, S1)
        np.sum([-20.0, -15.0, -18.0, -25.0, -26.0]),  # (S2, S1, S1, S2, S1)
        np.sum([-20.0, -15.0, -23.0, -19.0, -26.0]),  # (S2, S1, S2, S1, S1)
        np.sum([-20.0, -21.0, -16.0, -19.0, -26.0]),  # (S2, S2, S1, S1, S1)
    ])
    self.s1_seq_len = [3]
    self.s2_seq_len = [2]

  def ComputeLossByHand(self, sr_name, paths_p, paths_q):
    """Helper function to compute loss manually given the paths.

    Args:
      sr_name: Either 'logentropy' or 'logreversekl'.
      paths_p: Per-path log-weights under the `p` logits.
      paths_q: Per-path log-weights under the `q` logits.

    Returns:
      For 'logentropy', the 2-tuple (logp, logminusplogq). For
      'logreversekl', the 4-tuple (logp, logq, logminusqlogq, logminusqlogp).
      Every element is a logsumexp over the paths with keepdims=True.

    Raises:
      ValueError: If `sr_name` is not one of the two supported names.
    """

    def log_sum(log_terms):
      return tf.reduce_logsumexp(log_terms, keepdims=True)

    logp = log_sum(paths_p)
    if sr_name == 'logentropy':
      logminusplogq = log_sum(utils.logminus(paths_p, paths_q))
      return (logp, logminusplogq)
    if sr_name == 'logreversekl':
      logq = log_sum(paths_q)
      logminusqlogq = log_sum(utils.logminus(paths_q, paths_q))
      logminusqlogp = log_sum(utils.logminus(paths_q, paths_p))
      return (logp, logq, logminusqlogq, logminusqlogp)
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as a confusing comparison failure.
    raise ValueError(f'Unsupported semiring name: {sr_name!r}')

  @parameterized.parameters([
      ('logentropy', semiring.LogEntropySemiring()),
      ('logreversekl', semiring.LogReverseKLSemiring()),
  ])
  def testCTCSemiring(self, sr_name, sr):
    """Checks asr_loss.ctc_semiring on valid, invalid and truncated inputs."""
    sr_inputs = (self.ctc_logits_p, self.ctc_logits_q)

    # atol is set negligibly small so that the relative tolerance governs
    # the comparisons.
    loss = asr_loss.ctc_semiring(
        sr=sr,
        sr_inputs=sr_inputs,
        output_labels=self.output_labels,
        input_seq_len=self.input_seq_len,
        output_seq_len=self.output_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.ctc_short_paths_p,
                                     self.ctc_short_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)

    # Check that invalid losses are zero-ed out.
    loss = asr_loss.ctc_semiring(
        sr=sr,
        sr_inputs=sr_inputs,
        output_labels=self.output_labels,
        input_seq_len=self.invalid_input_seq_len,
        output_seq_len=self.invalid_output_seq_len)
    for component in loss:
      self.assertAllClose(component, np.array([0.0]), atol=1e-37)

    # Check that the unused logits are masked out.
    loss = asr_loss.ctc_semiring(
        sr=sr,
        sr_inputs=sr_inputs,
        output_labels=self.output_labels,
        input_seq_len=self.unused_input_seq_len,
        output_seq_len=self.unused_output_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.ctc_long_paths_p,
                                     self.ctc_long_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)

  @parameterized.parameters([
      ('logentropy', semiring.LogEntropySemiring()),
      ('logreversekl', semiring.LogReverseKLSemiring()),
  ])
  def testRNNTSemiring(self, sr_name, sr):
    """Checks asr_loss.rnnt_semiring on valid, empty and perturbed inputs."""
    s1_inputs = (self.rnnt_s1_logits_p, self.rnnt_s1_logits_q)
    s2_inputs = (self.rnnt_s2_logits_p, self.rnnt_s2_logits_q)

    loss = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=s1_inputs,
        s2_inputs=s2_inputs,
        s1_seq_len=self.s1_seq_len,
        s2_seq_len=self.s2_seq_len)
    by_hand = self.ComputeLossByHand(sr_name, self.rnnt_paths_p,
                                     self.rnnt_paths_q)
    self.assertAllClose(loss, by_hand, atol=1e-37)

    # Check that invalid losses are zero-ed out.
    loss_1 = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=s1_inputs,
        s2_inputs=s2_inputs,
        s1_seq_len=[0],
        s2_seq_len=self.s2_seq_len)
    loss_2 = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=s1_inputs,
        s2_inputs=s2_inputs,
        s1_seq_len=self.s1_seq_len,
        s2_seq_len=[0])
    zeros = tf.zeros_like(by_hand)

    self.assertAllClose(loss_1, zeros)
    self.assertAllClose(loss_2, zeros)

    # Check that the unused logits are masked out: every 0.0 entry is
    # replaced by 1.23 and the loss must still match the same expectation.
    rnnt_s1_logits_p = np.where(self.rnnt_s1_logits_p == 0.0, 1.23,
                                self.rnnt_s1_logits_p)
    rnnt_s1_logits_q = np.where(self.rnnt_s1_logits_q == 0.0, 1.23,
                                self.rnnt_s1_logits_q)
    rnnt_s2_logits_p = np.where(self.rnnt_s2_logits_p == 0.0, 1.23,
                                self.rnnt_s2_logits_p)
    rnnt_s2_logits_q = np.where(self.rnnt_s2_logits_q == 0.0, 1.23,
                                self.rnnt_s2_logits_q)

    loss = asr_loss.rnnt_semiring(
        sr=sr,
        s1_inputs=(rnnt_s1_logits_p, rnnt_s1_logits_q),
        s2_inputs=(rnnt_s2_logits_p, rnnt_s2_logits_q),
        s1_seq_len=self.s1_seq_len,
        s2_seq_len=self.s2_seq_len)
    # `by_hand` from the first check is reused: the enumerated paths are
    # unchanged by the perturbation.
    self.assertAllClose(loss, by_hand, atol=1e-37)
369
if __name__ == '__main__':
  # Run every test case in this module through the TensorFlow test runner.
  tf.test.main()