google-research

Форк
0
295 строк · 10.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation metrics for Schema-guided dialogue.

This library provides functions for calculating the evaluation metrics for a
single dialogue. The following metrics are defined:

(1) Active intent accuracy: The fraction of user turns for which the active
  intent has been correctly predicted.
(2) Slot tagging F1: The macro-averaged F1 score for tagging slot values for
  non-categorical slots. This metric is optional to report in the final paper
  if participants decide not to use slot tagging.
(3) Requested slots F1: The macro-averaged F1 score for requested slots over the
  turns. For a turn, if there are no requested slots in both the ground truth
  and the prediction, that turn is skipped. The reported number is the average
  F1 score for all un-skipped user turns. This metric is optional to report in
  the final paper.
(4) Average goal accuracy: For each turn, participants must predict a single
  value for each slot present in the dialogue state. The slots which have a
  non-empty assignment in the ground truth dialogue state are only considered.
  This is the average accuracy of predicting the value of a slot correctly. A
  fuzzy matching based score is used for non-categorical slots.
(5) Joint goal accuracy: This is the average accuracy of predicting all slot
  assignments for a turn correctly. A fuzzy matching based score is used for
  non-categorical slots. This is the primary evaluation metric used for ranking
  submissions. More details to follow with the evaluation script.
"""
41

42
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from fuzzywuzzy import fuzz
import numpy as np
49

50
# Container for the three slot-tagging / requested-slots scores.
F1Scores = collections.namedtuple("F1Scores", ["f1", "precision", "recall"])

# Evaluation and other relevant metrics for DSTC8 Schema-guided DST.
# (1) Active intent accuracy.
ACTIVE_INTENT_ACCURACY = "active_intent_accuracy"
# (2) Slot tagging F1.
SLOT_TAGGING_F1 = "slot_tagging_f1"
SLOT_TAGGING_PRECISION = "slot_tagging_precision"
SLOT_TAGGING_RECALL = "slot_tagging_recall"
# (3) Requested slots F1.
REQUESTED_SLOTS_F1 = "requested_slots_f1"
REQUESTED_SLOTS_PRECISION = "requested_slots_precision"
REQUESTED_SLOTS_RECALL = "requested_slots_recall"
# (4) Average goal accuracy.
AVERAGE_GOAL_ACCURACY = "average_goal_accuracy"
AVERAGE_CAT_ACCURACY = "average_cat_accuracy"
AVERAGE_NONCAT_ACCURACY = "average_noncat_accuracy"
# (5) Joint goal accuracy.
JOINT_GOAL_ACCURACY = "joint_goal_accuracy"
JOINT_CAT_ACCURACY = "joint_cat_accuracy"
JOINT_NONCAT_ACCURACY = "joint_noncat_accuracy"

# Sentinel reported when a metric has no applicable turns/slots.
NAN_VAL = "NA"
73

74

75
def compute_f1(list_ref, list_hyp):
  """Compute F1 score from reference (ground truth) list and hypothesis list.

  Args:
    list_ref: List of true elements.
    list_hyp: List of positive (retrieved) elements.

  Returns:
    A F1Scores object containing F1, precision, and recall scores.
  """
  ref = collections.Counter(list_ref)
  hyp = collections.Counter(list_hyp)
  true = sum(ref.values())
  positive = sum(hyp.values())
  # Counter intersection gives per-element min counts, i.e. matched elements.
  true_positive = sum((ref & hyp).values())
  # By convention, precision (recall) is 1.0 when there are no retrieved
  # (true) elements.
  precision = float(true_positive) / positive if positive else 1.0
  recall = float(true_positive) / true if true else 1.0
  if precision + recall > 0.0:
    f1 = 2.0 * precision * recall / (precision + recall)
  else:  # The F1-score is defined to be 0 if both precision and recall are 0.
    f1 = 0.0

  return F1Scores(f1=f1, precision=precision, recall=recall)
99

100

101
def fuzzy_string_match(str_ref, str_hyp):
  """Returns fuzzy string similarity score in range [0.0, 1.0]."""
  # token_sort_ratio returns an integer in [0, 100]; the higher the score,
  # the higher the similarity between the two strings. Normalize to [0, 1].
  return fuzz.token_sort_ratio(str_ref, str_hyp) / 100.0
106

107

108
def noncat_slot_value_match(str_ref_list, str_hyp, use_fuzzy_match):
  """Calculate non-categorical slots correctness.

  Args:
    str_ref_list: a list of reference strings.
    str_hyp: the hypothesis string.
    use_fuzzy_match: whether to use fuzzy string matching.

  Returns:
    score: The highest match score of the references and the hypothesis, in
      range [0.0, 1.0]. With exact matching the score is either 0.0 or 1.0.
      An empty reference list yields 0.0.
  """
  score = 0.0
  for str_ref in str_ref_list:
    if not use_fuzzy_match:
      match_score = float(str_ref == str_hyp)
    else:
      match_score = fuzzy_string_match(str_ref, str_hyp)
    # Keep the best score over all acceptable reference values.
    score = max(score, match_score)
  return score
127

128

129
def compare_slot_values(slot_values_ref, slot_values_hyp, service,
                        use_fuzzy_match):
  """Compare and get correctness of goal state's slot_values.

  Args:
    slot_values_ref: goal state slot_values from reference (ground truth).
    slot_values_hyp: goal state slot_values from hypothesis (prediction).
    service: a service data structure in the schema. We use it to obtain the
      list of slots in the service and infer whether a slot is categorical.
    use_fuzzy_match: whether to use fuzzy string matching for non-categorical
      slot values.

  Returns:
    (list_cor, slot_active, slot_cat)
    list_cor: list of correctness scores, each corresponding to one slot in the
        service. The score is a float either 0.0 or 1.0 for categorical slot,
        and in range [0.0, 1.0] for non-categorical slot.
    slot_active: list indicating whether the element in list_cor corresponds to
        an active ground-truth slot.
    slot_cat: list indicating whether the element in list_cor corresponds to a
        categorical slot.
  """
  list_cor = []
  slot_active = []
  slot_cat = []

  for slot in service["slots"]:
    slot_name = slot["name"]
    slot_cat.append(slot["is_categorical"])

    if slot_name in slot_values_ref:  # REF=active
      slot_active.append(True)
      if slot_name in slot_values_hyp:  # HYP=active, apply matching
        value_ref_list = slot_values_ref[slot_name]
        # Only the first hypothesis value is scored: a single predicted value
        # per slot is expected.
        value_hyp = slot_values_hyp[slot_name][0]
        if slot["is_categorical"]:
          # Categorical slots require a case-insensitive exact match against
          # the first reference value.
          cor = float(value_ref_list[0].lower() == value_hyp.lower())
        else:
          cor = noncat_slot_value_match(value_ref_list, value_hyp,
                                        use_fuzzy_match)

        list_cor.append(cor)
      else:  # HYP=off
        list_cor.append(0.0)
    else:  # REF=off
      slot_active.append(False)
      if slot_name in slot_values_hyp:  # HYP=active
        list_cor.append(0.0)
      else:  # HYP=off
        list_cor.append(1.0)

  assert len(list_cor) == len(service["slots"])
  assert len(slot_active) == len(service["slots"])
  assert len(slot_cat) == len(service["slots"])
  return list_cor, slot_active, slot_cat
184

185

186
def get_active_intent_accuracy(frame_ref, frame_hyp):
  """Get active intent accuracy of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.

  Returns:
    1.0 if the intent prediction is correct, otherwise 0.0. The comparison is
    case-insensitive.
  """
  return float(frame_ref["state"]["active_intent"].lower() == frame_hyp["state"]
               ["active_intent"].lower())
198

199

200
def get_slot_tagging_f1(frame_ref, frame_hyp, utt, service):
  """Get slot tagging (non-categorical slots only) F1 scores of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.
    utt: user utterance. Slot tagging annotations are the character positions in
      the utterance.
    service: a service data structure in the schema. We use it to infer whether
      a slot is non-categorical.

  Returns:
    A F1Scores object containing F1, precision, and recall scores, or None if
    the hypothesis frame carries no slot tagging output.
  """
  list_noncat_slots = [
      s["name"] for s in service["slots"] if not s["is_categorical"]
  ]
  if "slots" not in frame_hyp:
    return None
  else:
    # Each element is a (slot name, tagged substring) pair; spans are
    # half-open character ranges [start, exclusive_end) into the utterance.
    list_ref = [(s["slot"], utt[s["start"]:s["exclusive_end"]])
                for s in frame_ref["slots"]
                if s["slot"] in list_noncat_slots]
    list_hyp = [(s["slot"], utt[s["start"]:s["exclusive_end"]])
                for s in frame_hyp["slots"]
                if s["slot"] in list_noncat_slots]
    return compute_f1(list_ref, list_hyp)
228

229

230
def get_requested_slots_f1(frame_ref, frame_hyp):
  """Get requested slots F1 scores of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.

  Returns:
    A F1Scores object containing F1, precision, and recall scores.
  """
  return compute_f1(frame_ref["state"]["requested_slots"],
                    frame_hyp["state"]["requested_slots"])
242

243

244
def get_average_and_joint_goal_accuracy(frame_ref, frame_hyp, service,
                                        use_fuzzy_match):
  """Get average and joint goal accuracies of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.
    service: a service data structure in the schema. We use it to obtain the
      list of slots in the service and infer whether a slot is categorical.
    use_fuzzy_match: whether to use fuzzy string matching for comparing
      non-categorical slot values.

  Returns:
    goal_acc: a dict whose values are average / joint
        all-goal / categorical-goal / non-categorical-goal accuracies. A value
        is NAN_VAL when no slot of the corresponding kind is applicable.
  """
  goal_acc = {}

  list_acc, slot_active, slot_cat = compare_slot_values(
      frame_ref["state"]["slot_values"], frame_hyp["state"]["slot_values"],
      service, use_fuzzy_match)

  # (4) Average goal accuracy: mean score over slots active in the ground
  # truth only.
  active_acc = [acc for acc, active in zip(list_acc, slot_active) if active]
  goal_acc[AVERAGE_GOAL_ACCURACY] = np.mean(
      active_acc) if active_acc else NAN_VAL
  # (4-a) categorical.
  active_cat_acc = [
      acc for acc, active, cat in zip(list_acc, slot_active, slot_cat)
      if active and cat
  ]
  goal_acc[AVERAGE_CAT_ACCURACY] = (
      np.mean(active_cat_acc) if active_cat_acc else NAN_VAL)
  # (4-b) non-categorical.
  active_noncat_acc = [
      acc for acc, active, cat in zip(list_acc, slot_active, slot_cat)
      if active and not cat
  ]
  goal_acc[AVERAGE_NONCAT_ACCURACY] = (
      np.mean(active_noncat_acc) if active_noncat_acc else NAN_VAL)

  # (5) Joint goal accuracy: the product over ALL slots (inactive slots score
  # 1.0 only when both ref and hyp agree they are off).
  goal_acc[JOINT_GOAL_ACCURACY] = np.prod(list_acc) if list_acc else NAN_VAL
  # (5-a) categorical.
  cat_acc = [acc for acc, cat in zip(list_acc, slot_cat) if cat]
  goal_acc[JOINT_CAT_ACCURACY] = np.prod(cat_acc) if cat_acc else NAN_VAL
  # (5-b) non-categorical.
  noncat_acc = [acc for acc, cat in zip(list_acc, slot_cat) if not cat]
  goal_acc[JOINT_NONCAT_ACCURACY] = np.prod(
      noncat_acc) if noncat_acc else NAN_VAL

  return goal_acc
296

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.