# google-research
# 113 lines · 4.3 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Tests for feature_converters."""
from seqio import test_utils
import tensorflow.compat.v2 as tf

from kl_guided_sampling import feature_converters

# seqio's test helpers rely on eager-mode dataset evaluation.
tf.compat.v1.enable_eager_execution()

# Short aliases for the seqio test helpers used throughout this file.
assert_dataset = test_utils.assert_dataset
create_default_dataset = test_utils.create_default_dataset
class ContextualEncDecFeatureConverterTest(tf.test.TestCase):
  """Tests for ContextualEncDecFeatureConverter (unpacked mode).

  The converter emits the standard encoder-decoder features plus an
  `encoder_input_tokens_wo` feature; negative input tokens appear to mark a
  context prefix that is absolutized in the regular stream and dropped from
  the `_wo` ("without") stream — TODO confirm against the converter source.
  """

  def test_encoder_decoder_unpacked_all_positive(self):
    # With no negative tokens, the `_wo` encoder stream matches the regular
    # one exactly.
    source = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 1]}]
    lengths = {"inputs": 7, "targets": 5}
    dataset = create_default_dataset(source)

    conv = feature_converters.ContextualEncDecFeatureConverter(pack=False)
    result = conv(dataset, lengths)

    want = {}
    want["encoder_input_tokens"] = [9, 4, 3, 8, 1, 0, 0]
    want["encoder_input_tokens_wo"] = [9, 4, 3, 8, 1, 0, 0]
    want["decoder_target_tokens"] = [3, 9, 4, 1, 0]
    want["decoder_input_tokens"] = [0, 3, 9, 4, 1]
    want["decoder_loss_weights"] = [1, 1, 1, 1, 0]
    assert_dataset(result, want)

  def test_encoder_decoder_unpacked_some_negatives(self):
    # A leading negative token (-7) becomes 7 in the regular encoder stream
    # and is omitted entirely from the `_wo` stream.
    source = [{"inputs": [-7, 8, 5, 1], "targets": [3, 9, 1]}]
    lengths = {"inputs": 10, "targets": 7}
    dataset = create_default_dataset(source)

    conv = feature_converters.ContextualEncDecFeatureConverter(pack=False)
    result = conv(dataset, lengths)

    want = {}
    want["encoder_input_tokens"] = [7, 8, 5, 1, 0, 0, 0, 0, 0, 0]
    want["encoder_input_tokens_wo"] = [8, 5, 1, 0, 0, 0, 0, 0, 0, 0]
    want["decoder_target_tokens"] = [3, 9, 1, 0, 0, 0, 0]
    want["decoder_input_tokens"] = [0, 3, 9, 1, 0, 0, 0]
    want["decoder_loss_weights"] = [1, 1, 1, 0, 0, 0, 0]
    assert_dataset(result, want)
65
class ContextualPrefixLMFeatureConverter(tf.test.TestCase):
  """Tests for ContextualPrefixLMFeatureConverter (unpacked mode).

  NOTE(review): class name lacks the conventional `Test` suffix; kept as-is
  to preserve the public interface. The converter emits prefix-LM features
  and a parallel `_wo` set; negative input tokens appear to mark a context
  prefix that is dropped from the `_wo` stream — TODO confirm.
  """

  def test_prefix_lm_unpacked_all_positive(self):
    # All-positive inputs: the `_wo` features duplicate the regular ones.
    source = [{"inputs": [9, 4, 6, 1], "targets": [3, 9, 1]}]
    dataset = create_default_dataset(source)
    lengths = {"inputs": 5, "targets": 4}

    conv = feature_converters.ContextualPrefixLMFeatureConverter(
        pack=False)
    result = conv(dataset, lengths)

    base = {
        "decoder_target_tokens": [9, 4, 6, 1, 3, 9, 1, 0, 0],
        "decoder_input_tokens": [0, 9, 4, 6, 1, 3, 9, 1, 0],
        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 1, 0, 0],
        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0],
    }
    # The `_wo` variants mirror the regular features token-for-token here.
    base["decoder_target_tokens_wo"] = [9, 4, 6, 1, 3, 9, 1, 0, 0]
    base["decoder_input_tokens_wo"] = [0, 9, 4, 6, 1, 3, 9, 1, 0]
    base["decoder_loss_weights_wo"] = [0, 0, 0, 0, 1, 1, 1, 0, 0]
    base["decoder_causal_attention_wo"] = [1, 1, 1, 1, 1, 0, 0, 0, 0]
    assert_dataset(result, base)

  def test_prefix_lm_unpacked_some_negatives(self):
    # Leading negative token (-9): regular features carry 9, while the `_wo`
    # features start from the next token (4) and shift accordingly.
    source = [{"inputs": [-9, 4, 6, 1], "targets": [3, 9, 1]}]
    dataset = create_default_dataset(source)
    lengths = {"inputs": 10, "targets": 4}

    conv = feature_converters.ContextualPrefixLMFeatureConverter(
        pack=False)
    result = conv(dataset, lengths)

    want = {
        "decoder_target_tokens": [9, 4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0, 0],
        "decoder_input_tokens": [0, 9, 4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0],
        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        "decoder_target_tokens_wo": [4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        "decoder_input_tokens_wo": [0, 4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0, 0],
        "decoder_loss_weights_wo": [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        "decoder_causal_attention_wo":
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    }
    assert_dataset(result, want)
111
112if __name__ == '__main__':113tf.test.main()114