google-research

Форк
0
/
feature_converters_test.py 
113 строк · 4.3 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for feature_converters."""
17
from seqio import test_utils
18
import tensorflow.compat.v2 as tf
19

20
from kl_guided_sampling import feature_converters
21

22

23
tf.compat.v1.enable_eager_execution()
24

25
assert_dataset = test_utils.assert_dataset
26
create_default_dataset = test_utils.create_default_dataset
27

28

29
class ContextualEncDecFeatureConverterTest(tf.test.TestCase):
30

31
  def test_encoder_decoder_unpacked_all_positive(self):
32
    x = [{"inputs": [9, 4, 3, 8, 1], "targets": [3, 9, 4, 1]}]
33
    ds = create_default_dataset(x)
34
    task_feature_lengths = {"inputs": 7, "targets": 5}
35

36
    converter = feature_converters.ContextualEncDecFeatureConverter(pack=False)
37
    converted_ds = converter(ds, task_feature_lengths)
38

39
    expected = {
40
        "encoder_input_tokens": [9, 4, 3, 8, 1, 0, 0],
41
        "encoder_input_tokens_wo": [9, 4, 3, 8, 1, 0, 0],
42
        "decoder_target_tokens": [3, 9, 4, 1, 0],
43
        "decoder_input_tokens": [0, 3, 9, 4, 1],
44
        "decoder_loss_weights": [1, 1, 1, 1, 0],
45
    }
46
    assert_dataset(converted_ds, expected)
47

48
  def test_encoder_decoder_unpacked_some_negatives(self):
49
    x = [{"inputs": [-7, 8, 5, 1], "targets": [3, 9, 1]}]
50
    ds = create_default_dataset(x)
51
    task_feature_lengths = {"inputs": 10, "targets": 7}
52

53
    converter = feature_converters.ContextualEncDecFeatureConverter(pack=False)
54
    converted_ds = converter(ds, task_feature_lengths)
55

56
    expected = {
57
        "encoder_input_tokens": [7, 8, 5, 1, 0, 0, 0, 0, 0, 0],
58
        "encoder_input_tokens_wo": [8, 5, 1, 0, 0, 0, 0, 0, 0, 0],
59
        "decoder_target_tokens": [3, 9, 1, 0, 0, 0, 0],
60
        "decoder_input_tokens": [0, 3, 9, 1, 0, 0, 0],
61
        "decoder_loss_weights": [1, 1, 1, 0, 0, 0, 0],
62
    }
63
    assert_dataset(converted_ds, expected)
64

65

66
class ContextualPrefixLMFeatureConverter(tf.test.TestCase):
67

68
  def test_prefix_lm_unpacked_all_positive(self):
69
    x = [{"inputs": [9, 4, 6, 1], "targets": [3, 9, 1]}]
70
    ds = create_default_dataset(x)
71

72
    task_feature_lengths = {"inputs": 5, "targets": 4}
73
    converter = feature_converters.ContextualPrefixLMFeatureConverter(
74
        pack=False)
75
    converted_ds = converter(ds, task_feature_lengths)
76

77
    expected = {
78
        "decoder_target_tokens": [9, 4, 6, 1, 3, 9, 1, 0, 0],
79
        "decoder_input_tokens": [0, 9, 4, 6, 1, 3, 9, 1, 0],
80
        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 1, 0, 0],
81
        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0],
82
        "decoder_target_tokens_wo": [9, 4, 6, 1, 3, 9, 1, 0, 0],
83
        "decoder_input_tokens_wo": [0, 9, 4, 6, 1, 3, 9, 1, 0],
84
        "decoder_loss_weights_wo": [0, 0, 0, 0, 1, 1, 1, 0, 0],
85
        "decoder_causal_attention_wo": [1, 1, 1, 1, 1, 0, 0, 0, 0],
86
    }
87
    assert_dataset(converted_ds, expected)
88

89
  def test_prefix_lm_unpacked_some_negatives(self):
90
    x = [{"inputs": [-9, 4, 6, 1], "targets": [3, 9, 1]}]
91
    ds = create_default_dataset(x)
92

93
    task_feature_lengths = {"inputs": 10, "targets": 4}
94
    converter = feature_converters.ContextualPrefixLMFeatureConverter(
95
        pack=False)
96
    converted_ds = converter(ds, task_feature_lengths)
97

98
    expected = {
99
        "decoder_target_tokens": [9, 4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0, 0],
100
        "decoder_input_tokens": [0, 9, 4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0],
101
        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
102
        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
103
        "decoder_target_tokens_wo": [4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0, 0, 0],
104
        "decoder_input_tokens_wo": [0, 4, 6, 1, 3, 9, 1, 0, 0, 0, 0, 0, 0, 0],
105
        "decoder_loss_weights_wo": [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
106
        "decoder_causal_attention_wo":
107
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
108
    }
109
    assert_dataset(converted_ds, expected)
110

111

112
if __name__ == '__main__':
113
  tf.test.main()
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.