google-research
204 lines · 8.5 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Tests for schema_guided_dst.baseline.data_utils."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from absl.testing import absltest
25
26from schema_guided_dst.baseline import config
27from schema_guided_dst.baseline import data_utils
28
29_VOCAB_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
30'test_data/bert_vocab.txt')
31_TEST_DATA_DIR = os.path.dirname(_VOCAB_FILE)
32_DO_LOWER_CASE = True
33_DATASET = 'train'
34
35
class Dstc8DataProcessorTest(absltest.TestCase):
    """Tests for Dstc8DataProcessor."""

    def setUp(self):
        # Initialise the base-class fixture first, per unittest convention.
        super(Dstc8DataProcessorTest, self).setUp()
        # Processor over the bundled test data. The dev and test splits are
        # disabled; `range(1)` presumably restricts train to its first file
        # (TODO confirm against config.DatasetConfig).
        self._processor = data_utils.Dstc8DataProcessor(
            dstc8_data_dir=_TEST_DATA_DIR,
            dataset_config=config.DatasetConfig(
                file_ranges={
                    'train': range(1),
                    'dev': None,
                    'test': None
                },
                max_num_cat_slot=6,
                max_num_noncat_slot=6,
                max_num_value_per_cat_slot=4,
                max_num_intent=2),
            vocab_file=_VOCAB_FILE,
            do_lower_case=_DO_LOWER_CASE)

    def test_tokenizer(self):
        """Checks the tokens and alignments returned by `_tokenize`.

        `_tokenize` returns a 3-tuple. Judging from the expected values below:

        - the first item is the list of sub-word tokens;
        - the second maps the character offsets where each word starts and
          ends to the index of that word's first and last token respectively;
        - the third lists, for every token, the (start, end) character span
          of the word it came from.
        """
        # Normal sentence: punctuation is split off, and 'Hellboy' becomes two
        # sub-word tokens that share the word's character span (7, 13).
        test_utt_1 = 'Watch, Hellboy?'
        utt_1_tokens, utt_1_aligns, utt_1_inv_alignments = (
            self._processor._tokenize(test_utt_1))
        expected_utt_1_tokens = ['watch', ',', 'hell', '##boy', '?']
        expected_utt_1_aligns = {0: 0, 4: 0, 5: 1, 7: 2, 13: 3, 14: 4}
        expected_utt_1_inv_alignments = [
            (0, 4), (5, 5), (7, 13), (7, 13), (14, 14)]
        self.assertEqual(utt_1_tokens, expected_utt_1_tokens)
        self.assertEqual(utt_1_aligns, expected_utt_1_aligns)
        self.assertEqual(utt_1_inv_alignments, expected_utt_1_inv_alignments)

        # Extra spaces in the utterance. There are TWO spaces on each side of
        # the comma, which is what puts ',' at offset 7 and 'spaces' at
        # offsets 10-15 in the expectations below.
        test_utt_2 = 'Extra  ,  spaces'
        utt_2_tokens, utt_2_aligns, utt_2_inv_alignments = (
            self._processor._tokenize(test_utt_2))
        self.assertEqual(utt_2_tokens, ['extra', ',', 'spaces'])
        self.assertEqual(utt_2_aligns, {0: 0, 4: 0, 7: 1, 10: 2, 15: 2})
        self.assertEqual(utt_2_inv_alignments, [(0, 4), (7, 7), (10, 15)])

        # Literal '#' characters in the input come out as standalone '#'
        # tokens; they are not confused with the '##' continuation prefix.
        test_utt_3 = 'Extra## ##abc'
        utt_3_tokens, utt_3_aligns, utt_3_inv_alignments = (
            self._processor._tokenize(test_utt_3))
        self.assertEqual(utt_3_tokens,
                         ['extra', '#', '#', '#', '#', 'a', '##b', '##c'])
        self.assertEqual(utt_3_aligns, {
            0: 0,
            4: 0,
            5: 1,
            6: 2,
            8: 3,
            9: 4,
            10: 5,
            12: 7
        })
        self.assertEqual(utt_3_inv_alignments, [(0, 4), (5, 5), (6, 6), (8, 8),
                                                (9, 9), (10, 12), (10, 12),
                                                (10, 12)])

    def test_get_dialog_examples(self):
        """Checks `readable_summary` of the leading examples of `_DATASET`.

        Each expected summary covers one turn: the tokenized system/user
        utterance pair with its segment mask, schema sizes for the service,
        the active intent, and the slot values in the dialogue state.
        """
        # list() so len() works even if a lazy iterable is ever returned.
        examples = list(self._processor.get_dialog_examples(_DATASET))
        expected_summaries = [
            {
                'utt_tok_mask_pairs': [('[CLS]', 0), ('[SEP]', 0),
                                       ('i', 1), ("'", 1), ('m', 1),
                                       ('looking', 1), ('for', 1),
                                       ('apartments', 1), ('.', 1),
                                       ('[SEP]', 1)],
                'utt_len': 10,
                'num_categorical_slots': 4,
                'num_categorical_slot_values': [2, 4, 4, 2, 0, 0],
                'num_noncategorical_slots': 3,
                'service_name': 'Homes_1',
                'active_intent': 'FindApartment',
                'slot_values_in_state': {}
            },
            {
                'utt_tok_mask_pairs': [('[CLS]', 0), ('which', 0), ('area', 0),
                                       ('are', 0), ('you', 0), ('looking', 0),
                                       ('in', 0), ('?', 0), ('[SEP]', 0),
                                       ('i', 1), ('want', 1), ('an', 1),
                                       ('apartment', 1), ('in', 1),
                                       ('sa', 1), ('##n', 1), ('j', 1),
                                       ('##ose', 1), ('.', 1), ('[SEP]', 1)],
                'utt_len': 20,
                'num_categorical_slots': 4,
                'num_categorical_slot_values': [2, 4, 4, 2, 0, 0],
                'num_noncategorical_slots': 3,
                'service_name': 'Homes_1',
                'active_intent': 'FindApartment',
                'slot_values_in_state': {
                    'area': 'san jose'
                }
            },
            {
                'utt_tok_mask_pairs': [('[CLS]', 0), ('how', 0), ('many', 0),
                                       ('bedrooms', 0), ('do', 0), ('you', 0),
                                       ('want', 0), ('?', 0), ('[SEP]', 0),
                                       ('2', 1), ('bedrooms', 1), (',', 1),
                                       ('please', 1), ('.', 1), ('[SEP]', 1)],
                'utt_len': 15,
                'num_categorical_slots': 4,
                'num_categorical_slot_values': [2, 4, 4, 2, 0, 0],
                'num_noncategorical_slots': 3,
                'service_name': 'Homes_1',
                'active_intent': 'FindApartment',
                'slot_values_in_state': {
                    'number_of_beds': '2'
                }
            },
            {
                'utt_tok_mask_pairs': [
                    ('[CLS]', 0), ('there', 0), ("'", 0), ('s', 0), ('a', 0),
                    ('nice', 0), ('property', 0), ('called', 0), ('a', 0),
                    ('##ege', 0), ('##na', 0), ('at', 0), ('129', 0),
                    ('##0', 0), ('sa', 0), ('##n', 0), ('to', 0),
                    ('##mas', 0), ('a', 0), ('##quin', 0), ('##o', 0),
                    ('road', 0), ('.', 0), ('it', 0), ('has', 0), ('2', 0),
                    ('bedrooms', 0), (',', 0), ('1', 0), ('bath', 0),
                    (',', 0), ('and', 0), ('rent', 0), ('##s', 0),
                    ('for', 0), ('$', 0), ('2', 0), (',', 0), ('650', 0),
                    ('a', 0), ('month', 0), ('.', 0), ('[SEP]', 0),
                    ('can', 1), ('you', 1), ('find', 1), ('me', 1), ('a', 1),
                    ('three', 1), ('bedroom', 1), ('apartment', 1),
                    ('in', 1), ('liver', 1), ('##more', 1), ('?', 1),
                    ('[SEP]', 1)
                ],
                'utt_len': 56,
                'num_categorical_slots': 4,
                'num_categorical_slot_values': [2, 4, 4, 2, 0, 0],
                'num_noncategorical_slots': 3,
                'service_name': 'Homes_1',
                'active_intent': 'FindApartment',
                'slot_values_in_state': {
                    'number_of_beds': '3',
                    'area': 'livermore'
                }
            },
            {
                'utt_tok_mask_pairs': [
                    ('[CLS]', 0), ('there', 0), ("'", 0), ('s', 0), ('a', 0),
                    ('##cacia', 0), ('capital', 0), ('co', 0), ('##r', 0),
                    ('-', 0), ('iron', 0), ('##wood', 0), ('a', 0),
                    ('##p', 0), ('at', 0), ('56', 0), ('##43', 0),
                    ('ch', 0), ('##ar', 0), ('##lot', 0), ('##te', 0),
                    ('way', 0), ('.', 0), ('it', 0), ('has', 0), ('3', 0),
                    ('bedrooms', 0), (',', 0), ('3', 0), ('baths', 0),
                    (',', 0), ('and', 0), ('rent', 0), ('##s', 0),
                    ('for', 0), ('$', 0), ('4', 0), (',', 0), ('05', 0),
                    ('##0', 0), ('a', 0), ('month', 0), ('.', 0),
                    ('[SEP]', 0), ('that', 1), ('one', 1), ('sounds', 1),
                    ('good', 1), ('.', 1), ('thanks', 1), (',', 1),
                    ('that', 1), ("'", 1), ('s', 1), ('all', 1), ('i', 1),
                    ('need', 1), ('.', 1), ('[SEP]', 1)
                ],
                'utt_len': 59,
                'num_categorical_slots': 4,
                'num_categorical_slot_values': [2, 4, 4, 2, 0, 0],
                'num_noncategorical_slots': 3,
                'service_name': 'Homes_1',
                'active_intent': 'FindApartment',
                'slot_values_in_state': {
                    'property_name': 'acacia capital cor - ironwood ap'
                }
            },
        ]
        # zip() stops at the shorter input. Without this guard the test would
        # pass vacuously if the processor produced too few (or no) examples.
        self.assertGreaterEqual(len(examples), len(expected_summaries))
        for turn_idx, (example, gold) in enumerate(
                zip(examples, expected_summaries)):
            self.assertEqual(
                example.readable_summary, gold, msg='turn %d' % turn_idx)
201
202
if __name__ == '__main__':
    # Run the tests in this module through absl's test entry point.
    absltest.main()
205