google-research
137 lines · 5.4 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for metrics.py libraries.

There are two kinds of tests being run in this file. The first one is comparison
of metrics calculated for oracle predictions. The second is comparison of
metrics calculated for a known prediction with the known ground truth values.
"""

23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27import json
28import os
29
30from absl.testing import absltest
31
32from schema_guided_dst import metrics
33
34ACCURACY_METRICS = [
35metrics.AVERAGE_GOAL_ACCURACY,
36metrics.AVERAGE_CAT_ACCURACY,
37metrics.AVERAGE_NONCAT_ACCURACY,
38metrics.JOINT_GOAL_ACCURACY,
39metrics.JOINT_CAT_ACCURACY,
40metrics.JOINT_NONCAT_ACCURACY,
41]
42THIS_DIR = os.path.dirname(os.path.abspath(__file__))
43
44
class MetricsTest(absltest.TestCase):
  """Tests for the metric functions in metrics.py.

  Each test runs a metric twice: once with the reference frame used as its own
  prediction (the "oracle" case, where every metric must be exactly 1.0), and
  once with a known hypothesis frame whose expected metric values are stored in
  test_data/metrics_test_refdata.json.
  """

  def setUp(self):
    super(MetricsTest, self).setUp()
    # Reference/hypothesis frames plus the precomputed expected metric values
    # (for both fuzzy and exact non-categorical string matching).
    data_file = os.path.join(THIS_DIR, "test_data", "metrics_test_refdata.json")
    with open(data_file) as f:
      test_data = json.load(f)
    self.frame_ref = test_data["frame_ref"]
    self.frame_hyp = test_data["frame_hyp"]
    self.correct_metrics = test_data["metrics_hyp_fuzzy_match"]
    self.correct_metrics_exact_match = test_data["metrics_hyp_exact_match"]
    self.utterance = test_data["utterance"]

    schema_file = os.path.join(THIS_DIR, "test_data",
                               "metrics_test_refschema.json")
    with open(schema_file) as f:
      self.schema = json.load(f)

  def _assert_dicts_almost_equal(self, ref_dict, other_dict):
    """Asserts two dicts have the same keys and almost-equal float values."""
    self.assertCountEqual(ref_dict.keys(), other_dict.keys())
    for metric in ref_dict.keys():
      self.assertAlmostEqual(ref_dict[metric], other_dict[metric])

  def test_active_intent_accuracy(self):
    # (1) Test on oracle frame.
    intent_acc_oracle = metrics.get_active_intent_accuracy(
        self.frame_ref, self.frame_ref)
    self.assertAlmostEqual(1.0, intent_acc_oracle)

    # (2) Test on a previously known frame.
    intent_acc_hyp = metrics.get_active_intent_accuracy(self.frame_ref,
                                                        self.frame_hyp)
    self.assertAlmostEqual(self.correct_metrics[metrics.ACTIVE_INTENT_ACCURACY],
                           intent_acc_hyp)

  def test_slot_tagging_f1(self):
    # (1) Test on oracle frame.
    slot_tagging_f1_oracle = metrics.get_slot_tagging_f1(
        self.frame_ref, self.frame_ref, self.utterance, self.schema)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal(
        {k: 1.0 for k in self.correct_metrics[metrics.SLOT_TAGGING_F1]},
        slot_tagging_f1_oracle._asdict())

    # (2) Test on a previously known frame.
    slot_tagging_f1_hyp = metrics.get_slot_tagging_f1(self.frame_ref,
                                                      self.frame_hyp,
                                                      self.utterance,
                                                      self.schema)
    self._assert_dicts_almost_equal(
        self.correct_metrics[metrics.SLOT_TAGGING_F1],
        slot_tagging_f1_hyp._asdict())

  def test_requested_slots_f1(self):
    # (1) Test on oracle frame.
    requestable_slots_f1_oracle = metrics.get_requested_slots_f1(
        self.frame_ref, self.frame_ref)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal(
        {k: 1.0 for k in self.correct_metrics[metrics.REQUESTED_SLOTS_F1]},
        requestable_slots_f1_oracle._asdict())

    # (2) Test on a previously known frame.
    requested_slots_f1_hyp = metrics.get_requested_slots_f1(
        self.frame_ref, self.frame_hyp)
    self._assert_dicts_almost_equal(
        self.correct_metrics[metrics.REQUESTED_SLOTS_F1],
        requested_slots_f1_hyp._asdict())

  def test_average_and_joint_goal_accuracy(self):
    # (1) Test on oracle frame (strict string matching for non-categorical
    # values; final bool argument is use_fuzzy_match).
    goal_accuracy_oracle_strict = metrics.get_average_and_joint_goal_accuracy(
        self.frame_ref, self.frame_ref, self.schema, False)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal({k: 1.0 for k in ACCURACY_METRICS},
                                    goal_accuracy_oracle_strict)

    # (2) Test on a previously known frame.
    goal_accuracy_hyp = metrics.get_average_and_joint_goal_accuracy(
        self.frame_ref, self.frame_hyp, self.schema, True)
    self._assert_dicts_almost_equal(
        {k: self.correct_metrics[k] for k in ACCURACY_METRICS},
        goal_accuracy_hyp)
    # (3) Test using strict string matching for non-categorical slot values.
    goal_accuracy_hyp = metrics.get_average_and_joint_goal_accuracy(
        self.frame_ref, self.frame_hyp, self.schema, False)
    self._assert_dicts_almost_equal(
        {k: self.correct_metrics_exact_match[k] for k in ACCURACY_METRICS},
        goal_accuracy_hyp)


if __name__ == "__main__":
  absltest.main()