google-research

Форк
0
137 строк · 5.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Tests for metrics.py libraries.
17

18
There are two kinds of tests being run in this file. The first one is comparison
19
of metrics calculated for oracle predictions. The second is comparison of
20
metrics calculated for a known prediction with the known ground truth values.
21
"""
22

23
from __future__ import absolute_import
24
from __future__ import division
25
from __future__ import print_function
26

27
import json
28
import os
29

30
from absl.testing import absltest
31

32
from schema_guided_dst import metrics
33

34
# Goal-accuracy metric keys compared in test_average_and_joint_goal_accuracy:
# the three "average" variants followed by the three "joint" variants.
ACCURACY_METRICS = (
    [
        metrics.AVERAGE_GOAL_ACCURACY,
        metrics.AVERAGE_CAT_ACCURACY,
        metrics.AVERAGE_NONCAT_ACCURACY,
    ]
    + [
        metrics.JOINT_GOAL_ACCURACY,
        metrics.JOINT_CAT_ACCURACY,
        metrics.JOINT_NONCAT_ACCURACY,
    ]
)
# Absolute path of the directory containing this file; the test fixtures are
# read from its "test_data" subdirectory.
THIS_DIR = os.path.split(os.path.abspath(__file__))[0]
43

44

45
class MetricsTest(absltest.TestCase):
  """Checks metrics.py against oracle frames and stored reference values.

  Fixtures are read from THIS_DIR/test_data:
    * metrics_test_refdata.json: the reference frame ("frame_ref"), a
      hypothesis frame ("frame_hyp"), the utterance, and the expected metric
      values for the hypothesis under fuzzy ("metrics_hyp_fuzzy_match") and
      exact ("metrics_hyp_exact_match") matching of non-categorical values.
    * metrics_test_refschema.json: the schema handed to get_slot_tagging_f1
      and get_average_and_joint_goal_accuracy.
  """

  # Values for the last positional argument of
  # metrics.get_average_and_joint_goal_accuracy. True is paired with the
  # "metrics_hyp_fuzzy_match" expectations; False with strict (exact) string
  # matching of non-categorical slot values.
  _FUZZY_MATCH = True
  _EXACT_MATCH = False

  def setUp(self):
    super(MetricsTest, self).setUp()
    test_data = self._load_json("metrics_test_refdata.json")
    self.frame_ref = test_data["frame_ref"]
    self.frame_hyp = test_data["frame_hyp"]
    # Expected metrics for frame_hyp with fuzzy matching of non-categorical
    # slot values.
    self.correct_metrics = test_data["metrics_hyp_fuzzy_match"]
    # Expected metrics for frame_hyp with exact string matching.
    self.correct_metrics_exact_match = test_data["metrics_hyp_exact_match"]
    self.utterance = test_data["utterance"]
    self.schema = self._load_json("metrics_test_refschema.json")

  @staticmethod
  def _load_json(filename):
    """Returns the parsed contents of the JSON fixture THIS_DIR/test_data/<filename>."""
    path = os.path.join(THIS_DIR, "test_data", filename)
    with open(path) as f:
      return json.load(f)

  def _assert_dicts_almost_equal(self, ref_dict, other_dict):
    """Asserts both dicts have the same keys and almost-equal values.

    Args:
      ref_dict: Expected mapping from metric name to value.
      other_dict: Actual mapping from metric name to value.
    """
    self.assertCountEqual(ref_dict.keys(), other_dict.keys())
    for metric, expected in ref_dict.items():
      # Name the metric so a failure points at the offending entry.
      self.assertAlmostEqual(
          expected, other_dict[metric], msg="metric %r" % (metric,))

  def test_active_intent_accuracy(self):
    """Active intent accuracy: 1.0 for the oracle, stored value for frame_hyp."""
    # (1) Test on oracle frame.
    intent_acc_oracle = metrics.get_active_intent_accuracy(
        self.frame_ref, self.frame_ref)
    self.assertAlmostEqual(1.0, intent_acc_oracle)

    # (2) Test on a previously known frame.
    intent_acc_hyp = metrics.get_active_intent_accuracy(
        self.frame_ref, self.frame_hyp)
    self.assertAlmostEqual(self.correct_metrics[metrics.ACTIVE_INTENT_ACCURACY],
                           intent_acc_hyp)

  def test_slot_tagging_f1(self):
    """Slot tagging F1 fields: all 1.0 for the oracle, stored values for frame_hyp."""
    # (1) Test on oracle frame.
    slot_tagging_f1_oracle = metrics.get_slot_tagging_f1(
        self.frame_ref, self.frame_ref, self.utterance, self.schema)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal(
        {k: 1.0 for k in self.correct_metrics[metrics.SLOT_TAGGING_F1]},
        slot_tagging_f1_oracle._asdict())

    # (2) Test on a previously known frame.
    slot_tagging_f1_hyp = metrics.get_slot_tagging_f1(
        self.frame_ref, self.frame_hyp, self.utterance, self.schema)
    self._assert_dicts_almost_equal(
        self.correct_metrics[metrics.SLOT_TAGGING_F1],
        slot_tagging_f1_hyp._asdict())

  def test_requested_slots_f1(self):
    """Requested slots F1 fields: all 1.0 for the oracle, stored values for frame_hyp."""
    # (1) Test on oracle frame.
    requested_slots_f1_oracle = metrics.get_requested_slots_f1(
        self.frame_ref, self.frame_ref)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal(
        {k: 1.0 for k in self.correct_metrics[metrics.REQUESTED_SLOTS_F1]},
        requested_slots_f1_oracle._asdict())

    # (2) Test on a previously known frame.
    requested_slots_f1_hyp = metrics.get_requested_slots_f1(
        self.frame_ref, self.frame_hyp)
    self._assert_dicts_almost_equal(
        self.correct_metrics[metrics.REQUESTED_SLOTS_F1],
        requested_slots_f1_hyp._asdict())

  def test_average_and_joint_goal_accuracy(self):
    """Goal accuracies for the oracle and for frame_hyp under both match modes."""
    # (1) Test on oracle frame with exact matching.
    goal_accuracy_oracle = metrics.get_average_and_joint_goal_accuracy(
        self.frame_ref, self.frame_ref, self.schema, self._EXACT_MATCH)
    # Ground truth values for oracle prediction are all 1.0.
    self._assert_dicts_almost_equal({k: 1.0 for k in ACCURACY_METRICS},
                                    goal_accuracy_oracle)

    # (2) Test on a previously known frame using fuzzy matching for
    # non-categorical slot values.
    goal_accuracy_fuzzy = metrics.get_average_and_joint_goal_accuracy(
        self.frame_ref, self.frame_hyp, self.schema, self._FUZZY_MATCH)
    self._assert_dicts_almost_equal(
        {k: self.correct_metrics[k] for k in ACCURACY_METRICS},
        goal_accuracy_fuzzy)

    # (3) Test using strict string matching for non-categorical slot values.
    goal_accuracy_exact = metrics.get_average_and_joint_goal_accuracy(
        self.frame_ref, self.frame_hyp, self.schema, self._EXACT_MATCH)
    self._assert_dicts_almost_equal(
        {k: self.correct_metrics_exact_match[k] for k in ACCURACY_METRICS},
        goal_accuracy_exact)
134

135

136
if __name__ == "__main__":
  # Run the tests defined in this module with absl's test runner.
  absltest.main()
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.