google-research
295 lines · 10.9 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation metrics for Schema-guided dialogue.

This library provides functions for calculating the evaluation metrics for a
single dialogue. The following metrics are defined:

(1) Active intent accuracy: The fraction of user turns for which the active
  intent has been correctly predicted.
(2) Slot tagging F1: The macro-averaged F1 score for tagging slot values for
  non-categorical slots. This metric is optional to report in the final paper
  if participants decide not to use slot tagging.
(3) Requested slots F1: The macro-averaged F1 score for requested slots over
  the turns. For a turn, if there are no requested slots in both the ground
  truth and the prediction, that turn is skipped. The reported number is the
  average F1 score for all un-skipped user turns. This metric is optional to
  report in the final paper.
(4) Average goal accuracy: For each turn, participants must predict a single
  value for each slot present in the dialogue state. The slots which have a
  non-empty assignment in the ground truth dialogue state are only considered.
  This is the average accuracy of predicting the value of a slot correctly. A
  fuzzy matching based score is used for non-categorical slots.
(5) Joint goal accuracy: This is the average accuracy of predicting all slot
  assignments for a turn correctly. A fuzzy matching based score is used for
  non-categorical slots. This is the primary evaluation metric used for
  ranking submissions. More details to follow with the evaluation script.
"""
41
42from __future__ import absolute_import
43from __future__ import division
44from __future__ import print_function
45
46import collections
47from fuzzywuzzy import fuzz
48import numpy as np
49
# Container bundling the three scores produced by `compute_f1`.
F1Scores = collections.namedtuple("F1Scores", ["f1", "precision", "recall"])

# Metric names reported for DSTC8 Schema-guided DST evaluation.
# (1) Active intent accuracy.
ACTIVE_INTENT_ACCURACY = "active_intent_accuracy"
# (2) Slot tagging F1 (non-categorical slots only).
SLOT_TAGGING_F1 = "slot_tagging_f1"
SLOT_TAGGING_PRECISION = "slot_tagging_precision"
SLOT_TAGGING_RECALL = "slot_tagging_recall"
# (3) Requested slots F1.
REQUESTED_SLOTS_F1 = "requested_slots_f1"
REQUESTED_SLOTS_PRECISION = "requested_slots_precision"
REQUESTED_SLOTS_RECALL = "requested_slots_recall"
# (4) Average goal accuracy, overall and per slot kind.
AVERAGE_GOAL_ACCURACY = "average_goal_accuracy"
AVERAGE_CAT_ACCURACY = "average_cat_accuracy"
AVERAGE_NONCAT_ACCURACY = "average_noncat_accuracy"
# (5) Joint goal accuracy, overall and per slot kind.
JOINT_GOAL_ACCURACY = "joint_goal_accuracy"
JOINT_CAT_ACCURACY = "joint_cat_accuracy"
JOINT_NONCAT_ACCURACY = "joint_noncat_accuracy"

# Sentinel reported when a metric is undefined for a turn (no eligible slots).
NAN_VAL = "NA"

74
def compute_f1(list_ref, list_hyp):
  """Compute F1 score from reference (ground truth) list and hypothesis list.

  Multiset semantics: duplicate elements count once per occurrence on each
  side, and a true positive is an occurrence present in both multisets.

  Args:
    list_ref: List of true elements.
    list_hyp: List of positive (retrieved) elements.

  Returns:
    A F1Scores object containing F1, precision, and recall scores.
  """
  ref_counts = collections.Counter(list_ref)
  hyp_counts = collections.Counter(list_hyp)
  num_true = sum(ref_counts.values())
  num_positive = sum(hyp_counts.values())
  # Multiset intersection keeps the minimum count of each shared element.
  num_true_positive = sum((ref_counts & hyp_counts).values())

  # By convention an empty side scores 1.0 (nothing retrieved / nothing true).
  precision = float(num_true_positive) / num_positive if num_positive else 1.0
  recall = float(num_true_positive) / num_true if num_true else 1.0
  denominator = precision + recall
  # The F1-score is defined to be 0 if both precision and recall are 0.
  f1 = 2.0 * precision * recall / denominator if denominator > 0.0 else 0.0

  return F1Scores(f1=f1, precision=precision, recall=recall)
99
100
def fuzzy_string_match(str_ref, str_hyp):
  """Returns fuzzy string similarity score in range [0.0, 1.0]."""
  # fuzz.token_sort_ratio yields an integer percentage in [0, 100]; higher
  # means the two strings are more similar. Rescale to [0.0, 1.0].
  ratio = fuzz.token_sort_ratio(str_ref, str_hyp)
  return ratio / 100.0
106
107
def noncat_slot_value_match(str_ref_list, str_hyp, use_fuzzy_match):
  """Calculate non-categorical slots correctness.

  Args:
    str_ref_list: a list of reference strings.
    str_hyp: the hypothesis string.
    use_fuzzy_match: whether to use fuzzy string matching.

  Returns:
    score: The highest match score between the hypothesis and any reference;
      0.0 when the reference list is empty.
  """
  if use_fuzzy_match:
    scores = (fuzzy_string_match(ref, str_hyp) for ref in str_ref_list)
  else:
    # Exact match only: 1.0 or 0.0 per reference.
    scores = (float(ref == str_hyp) for ref in str_ref_list)
  return max(scores, default=0.0)
127
128
def compare_slot_values(slot_values_ref, slot_values_hyp, service,
                        use_fuzzy_match):
  """Compare and get correctness of goal state's slot_values.

  Args:
    slot_values_ref: goal state slot_values from reference (ground truth).
    slot_values_hyp: goal state slot_values from hypothesis (prediction).
    service: a service data structure in the schema. We use it to obtain the
      list of slots in the service and infer whether a slot is categorical.
    use_fuzzy_match: whether to use fuzzy string matching for non-categorical
      slot values.

  Returns:
    (list_cor, slot_active, slot_cat)
    list_cor: list of correctness scores, each corresponding to one slot in
      the service. The score is a float either 0.0 or 1.0 for categorical
      slot, and in range [0.0, 1.0] for non-categorical slot.
    slot_active: list indicating whether the element in list_cor corresponds
      to an active ground-truth slot.
    slot_cat: list indicating whether the element in list_cor corresponds to
      a categorical slot.
  """
  list_cor = []
  slot_active = []
  slot_cat = []

  for slot in service["slots"]:
    slot_name = slot["name"]
    is_categorical = slot["is_categorical"]
    in_ref = slot_name in slot_values_ref
    in_hyp = slot_name in slot_values_hyp
    slot_cat.append(is_categorical)
    slot_active.append(in_ref)

    if in_ref and in_hyp:
      # Both active: score the predicted value against the references.
      value_ref_list = slot_values_ref[slot_name]
      value_hyp = slot_values_hyp[slot_name][0]
      if is_categorical:
        # Categorical values must match exactly (case-insensitive).
        score = float(value_ref_list[0].lower() == value_hyp.lower())
      else:
        score = noncat_slot_value_match(value_ref_list, value_hyp,
                                        use_fuzzy_match)
      list_cor.append(score)
    elif in_ref or in_hyp:
      # Active on exactly one side: a miss either way.
      list_cor.append(0.0)
    else:
      # Inactive on both sides: counted as correct agreement.
      list_cor.append(1.0)

  num_slots = len(service["slots"])
  assert len(list_cor) == num_slots
  assert len(slot_active) == num_slots
  assert len(slot_cat) == num_slots
  return list_cor, slot_active, slot_cat
184
185
def get_active_intent_accuracy(frame_ref, frame_hyp):
  """Get active intent accuracy of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.

  Returns:
    1.0 if the intent prediction is correct, otherwise 0.0.
  """
  # Intent comparison is case-insensitive.
  intent_ref = frame_ref["state"]["active_intent"].lower()
  intent_hyp = frame_hyp["state"]["active_intent"].lower()
  return 1.0 if intent_ref == intent_hyp else 0.0
198
199
def get_slot_tagging_f1(frame_ref, frame_hyp, utt, service):
  """Get slot tagging (non-categorical slots only) F1 scores of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.
    utt: user utterance. Slot tagging annotations are the character positions
      in the utterance.
    service: a service data structure in the schema. We use it to infer
      whether a slot is non-categorical.

  Returns:
    A F1Scores object containing F1, precision, and recall scores, or None
    when the hypothesis frame carries no slot tagging output at all.
  """
  if "slots" not in frame_hyp:
    return None

  noncat_slot_names = {
      slot["name"] for slot in service["slots"] if not slot["is_categorical"]
  }

  def _tagged_spans(frame):
    # Pair each tagged slot with the utterance substring it covers.
    return [(s["slot"], utt[s["start"]:s["exclusive_end"]])
            for s in frame["slots"]
            if s["slot"] in noncat_slot_names]

  return compute_f1(_tagged_spans(frame_ref), _tagged_spans(frame_hyp))
228
229
def get_requested_slots_f1(frame_ref, frame_hyp):
  """Get requested slots F1 scores of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.

  Returns:
    A F1Scores object containing F1, precision, and recall scores.
  """
  requested_ref = frame_ref["state"]["requested_slots"]
  requested_hyp = frame_hyp["state"]["requested_slots"]
  return compute_f1(requested_ref, requested_hyp)
242
243
def get_average_and_joint_goal_accuracy(frame_ref, frame_hyp, service,
                                        use_fuzzy_match):
  """Get average and joint goal accuracies of a frame.

  Args:
    frame_ref: single semantic frame from reference (ground truth) file.
    frame_hyp: single semantic frame from hypothesis (prediction) file.
    service: a service data structure in the schema. We use it to obtain the
      list of slots in the service and infer whether a slot is categorical.
    use_fuzzy_match: whether to use fuzzy string matching for comparing
      non-categorical slot values.

  Returns:
    goal_acc: a dict whose values are average / joint
      all-goal / categorical-goal / non-categorical-goal accuracies, with
      NAN_VAL standing in when a metric has no eligible slots.
  """
  list_acc, slot_active, slot_cat = compare_slot_values(
      frame_ref["state"]["slot_values"], frame_hyp["state"]["slot_values"],
      service, use_fuzzy_match)

  def _mean_or_na(scores):
    # Averages are undefined over an empty selection.
    return np.mean(scores) if scores else NAN_VAL

  def _prod_or_na(scores):
    # Joint accuracy requires every selected slot to be correct.
    return np.prod(scores) if scores else NAN_VAL

  triples = list(zip(list_acc, slot_active, slot_cat))

  goal_acc = {
      # (4) Average goal accuracy over ground-truth-active slots.
      AVERAGE_GOAL_ACCURACY:
          _mean_or_na([acc for acc, active, _ in triples if active]),
      # (4-a) categorical.
      AVERAGE_CAT_ACCURACY:
          _mean_or_na([acc for acc, active, cat in triples if active and cat]),
      # (4-b) non-categorical.
      AVERAGE_NONCAT_ACCURACY:
          _mean_or_na(
              [acc for acc, active, cat in triples if active and not cat]),
      # (5) Joint goal accuracy over all slots of the service.
      JOINT_GOAL_ACCURACY: _prod_or_na(list_acc),
      # (5-a) categorical.
      JOINT_CAT_ACCURACY:
          _prod_or_na([acc for acc, _, cat in triples if cat]),
      # (5-b) non-categorical.
      JOINT_NONCAT_ACCURACY:
          _prod_or_na([acc for acc, _, cat in triples if not cat]),
  }
  return goal_acc
296