# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluate predictions JSON files w.r.t. the ground truth files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import json
import os
import numpy as np
import tensorflow.compat.v1 as tf

from schema_guided_dst import metrics

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "prediction_dir", None,
    "Directory in which all JSON files combined are predictions of the"
    " evaluation set on a single model checkpoint. We evaluate these JSON files"
    " by DSTC8 metrics.")
flags.DEFINE_string(
    "dstc8_data_dir", None,
    "Directory for the downloaded DSTC8 data, which contains the dialogue files"
    " and schema files of all datasets (train, dev, test).")
flags.DEFINE_enum("eval_set", None, ["train", "dev", "test"],
                  "Dataset split for evaluation.")
flags.DEFINE_string(
    "output_metric_file", None,
    "Single JSON output file containing aggregated evaluation metrics results"
    " for all predictions files in FLAGS.prediction_dir.")
flags.DEFINE_boolean(
    "joint_acc_across_turn", False,
    "Whether to compute joint accuracy across turn instead of across service. "
    "Should be set to True when conducting MultiWOZ-style evaluation.")
flags.DEFINE_boolean(
    "use_fuzzy_match", True,
    "Whether to use fuzzy string matching when comparing non-categorical slot "
    "values. Should be set to False when conducting MultiWOZ-style evaluation.")
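
# Example invocation, assuming this file lives at schema_guided_dst/evaluate.py
# as in the google-research repository (all paths below are illustrative):
#   python -m schema_guided_dst.evaluate \
#     --dstc8_data_dir=/path/to/dstc8-schema-guided-dialogue \
#     --prediction_dir=/path/to/model/predictions \
#     --eval_set=dev \
#     --output_metric_file=/path/to/metrics.json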

ALL_SERVICES = "#ALL_SERVICES"
SEEN_SERVICES = "#SEEN_SERVICES"
UNSEEN_SERVICES = "#UNSEEN_SERVICES"

# Name of the file containing all predictions and their corresponding frame
# metrics.
PER_FRAME_OUTPUT_FILENAME = "dialogues_and_metrics.json"


def get_service_set(schema_path):
  """Get the set of all services present in a schema."""
  service_set = set()
  with tf.gfile.GFile(schema_path) as f:
    schema = json.load(f)
    for service in schema:
      service_set.add(service["service_name"])
  return service_set


def get_in_domain_services(schema_path_1, schema_path_2):
  """Get the set of common services between two schemas."""
  return get_service_set(schema_path_1) & get_service_set(schema_path_2)


def get_dataset_as_dict(file_path_patterns):
  """Read the DSTC8 JSON dialogue data as a dictionary keyed by dialogue ID."""
  dataset_dict = {}
  if isinstance(file_path_patterns, list):
    list_fp = file_path_patterns
  else:
    list_fp = sorted(tf.gfile.Glob(file_path_patterns))
  for fp in list_fp:
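    # Skip the per-frame metrics file that a previous evaluation run may have
    # written into the prediction directory.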
    if PER_FRAME_OUTPUT_FILENAME in fp:
      continue
    tf.logging.info("Loading file: %s", fp)
    with tf.gfile.GFile(fp) as f:
      data = json.load(f)
      if isinstance(data, list):
        for dial in data:
          dataset_dict[dial["dialogue_id"]] = dial
      elif isinstance(data, dict):
        dataset_dict.update(data)
  return dataset_dict


def get_metrics(dataset_ref, dataset_hyp, service_schemas, in_domain_services):
  """Calculate the DSTC8 metrics.

  Args:
    dataset_ref: The ground truth dataset represented as a dict mapping dialogue
      id to the corresponding dialogue.
    dataset_hyp: The predictions in the same format as `dataset_ref`.
    service_schemas: A dict mapping service name to the schema for the service.
    in_domain_services: The set of services which are present in the training
      set.

  Returns:
    A tuple of (all_metric_aggregate, per_frame_metric). The first element is
    a dict mapping a metric collection name to a dict containing the values
    for various metrics; each metric collection aggregates the metrics across
    a specific set of frames in the dialogues. The second element maps a frame
    id to the metrics computed for that frame, for debugging.
  """
  # Metrics can be aggregated in various ways, e.g. over all dialogues, only
  # for dialogues containing unseen services, or for dialogues corresponding
  # to a single service. This aggregation is done through metric_collections,
  # which is a dict mapping a collection name to a dict, which maps a metric
  # to a list of values for that metric. Each value in this list is the value
  # taken by the metric on a frame.
  metric_collections = collections.defaultdict(
      lambda: collections.defaultdict(list))

  # Ensure the dialogs in dataset_hyp also occur in dataset_ref.
  assert set(dataset_hyp.keys()).issubset(set(dataset_ref.keys()))

  # Store metrics for every frame for debugging.
  per_frame_metric = {}
  for dial_id, dial_hyp in dataset_hyp.items():
    dial_ref = dataset_ref[dial_id]

    if set(dial_ref["services"]) != set(dial_hyp["services"]):
      raise ValueError(
          "Set of services present in ground truth and predictions don't match "
          "for dialogue with id {}".format(dial_id))
    joint_metrics = [
        metrics.JOINT_GOAL_ACCURACY, metrics.JOINT_CAT_ACCURACY,
        metrics.JOINT_NONCAT_ACCURACY
    ]
    for turn_id, (turn_ref, turn_hyp) in enumerate(
        zip(dial_ref["turns"], dial_hyp["turns"])):
      metric_collections_per_turn = collections.defaultdict(
          lambda: collections.defaultdict(lambda: 1.0))
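      # Values default to 1.0 (the multiplicative identity), so the running
      # product over frames below is well defined for each joint metric.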
      if turn_ref["speaker"] != turn_hyp["speaker"]:
        raise ValueError(
            "Speakers don't match in dialogue with id {}".format(dial_id))

      # Skip system turns because metrics are only computed for user turns.
      if turn_ref["speaker"] != "USER":
        continue

      if turn_ref["utterance"] != turn_hyp["utterance"]:
        tf.logging.info("Ref utt: %s", turn_ref["utterance"])
        tf.logging.info("Hyp utt: %s", turn_hyp["utterance"])
        raise ValueError(
            "Utterances don't match for dialogue with id {}".format(dial_id))

      hyp_frames_by_service = {
          frame["service"]: frame for frame in turn_hyp["frames"]
      }

      # Calculate metrics for each frame in each user turn.
      for frame_ref in turn_ref["frames"]:
        service_name = frame_ref["service"]
        if service_name not in hyp_frames_by_service:
          raise ValueError(
              "Frame for service {} not found in dialogue with id {}".format(
                  service_name, dial_id))
        service = service_schemas[service_name]
        frame_hyp = hyp_frames_by_service[service_name]

        active_intent_acc = metrics.get_active_intent_accuracy(
            frame_ref, frame_hyp)
        slot_tagging_f1_scores = metrics.get_slot_tagging_f1(
            frame_ref, frame_hyp, turn_ref["utterance"], service)
        requested_slots_f1_scores = metrics.get_requested_slots_f1(
            frame_ref, frame_hyp)
        goal_accuracy_dict = metrics.get_average_and_joint_goal_accuracy(
            frame_ref, frame_hyp, service, FLAGS.use_fuzzy_match)

        frame_metric = {
            metrics.ACTIVE_INTENT_ACCURACY:
                active_intent_acc,
            metrics.REQUESTED_SLOTS_F1:
                requested_slots_f1_scores.f1,
            metrics.REQUESTED_SLOTS_PRECISION:
                requested_slots_f1_scores.precision,
            metrics.REQUESTED_SLOTS_RECALL:
                requested_slots_f1_scores.recall
        }
        if slot_tagging_f1_scores is not None:
          frame_metric[metrics.SLOT_TAGGING_F1] = slot_tagging_f1_scores.f1
          frame_metric[metrics.SLOT_TAGGING_PRECISION] = (
              slot_tagging_f1_scores.precision)
          frame_metric[
              metrics.SLOT_TAGGING_RECALL] = slot_tagging_f1_scores.recall
        frame_metric.update(goal_accuracy_dict)

        frame_id = "{:s}-{:03d}-{:s}".format(dial_id, turn_id,
                                             frame_hyp["service"])
        per_frame_metric[frame_id] = frame_metric
        # Add the frame-level metric result back to dialogues.
        frame_hyp["metrics"] = frame_metric

        # Get the domain name of the service.
        domain_name = frame_hyp["service"].split("_")[0]
        domain_keys = [ALL_SERVICES, frame_hyp["service"], domain_name]
        if frame_hyp["service"] in in_domain_services:
          domain_keys.append(SEEN_SERVICES)
        else:
          domain_keys.append(UNSEEN_SERVICES)
        for domain_key in domain_keys:
          for metric_key, metric_value in frame_metric.items():
            if metric_value != metrics.NAN_VAL:
              if FLAGS.joint_acc_across_turn and metric_key in joint_metrics:
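                # MultiWOZ-style: the turn's joint accuracy is the product of
                # the per-frame values, i.e. 1.0 only if every frame is right.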
                metric_collections_per_turn[domain_key][
                    metric_key] *= metric_value
              else:
                metric_collections[domain_key][metric_key].append(metric_value)
      if FLAGS.joint_acc_across_turn:
        # Conduct MultiWOZ-style evaluation that computes joint goal accuracy
        # across all the slot values of all the domains for each turn.
        for domain_key in metric_collections_per_turn:
          for metric_key, metric_value in metric_collections_per_turn[
              domain_key].items():
            metric_collections[domain_key][metric_key].append(metric_value)
  all_metric_aggregate = {}
  for domain_key, domain_metric_vals in metric_collections.items():
    domain_metric_aggregate = {}
    for metric_key, value_list in domain_metric_vals.items():
      if value_list:
        # Metrics are macro-averaged across all frames.
        domain_metric_aggregate[metric_key] = float(np.mean(value_list))
      else:
        domain_metric_aggregate[metric_key] = metrics.NAN_VAL
    all_metric_aggregate[domain_key] = domain_metric_aggregate
  return all_metric_aggregate, per_frame_metric


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  in_domain_services = get_in_domain_services(
      os.path.join(FLAGS.dstc8_data_dir, FLAGS.eval_set, "schema.json"),
      os.path.join(FLAGS.dstc8_data_dir, "train", "schema.json"))
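  # Services that also appear in the train schema are "seen"; get_metrics
  # aggregates their frames under SEEN_SERVICES and the rest under
  # UNSEEN_SERVICES.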
  with tf.io.gfile.GFile(
      os.path.join(FLAGS.dstc8_data_dir, FLAGS.eval_set, "schema.json")) as f:
    eval_services = {}
    list_services = json.load(f)
    for service in list_services:
      eval_services[service["service_name"]] = service

  dataset_ref = get_dataset_as_dict(
      os.path.join(FLAGS.dstc8_data_dir, FLAGS.eval_set, "dialogues_*.json"))
  dataset_hyp = get_dataset_as_dict(
      os.path.join(FLAGS.prediction_dir, "*.json"))
  tf.logging.info("len(dataset_hyp)=%d, len(dataset_ref)=%d", len(dataset_hyp),
                  len(dataset_ref))
  if not dataset_hyp or not dataset_ref:
    raise ValueError("Hypothesis and/or reference dataset are empty!")

  all_metric_aggregate, _ = get_metrics(dataset_ref, dataset_hyp, eval_services,
                                        in_domain_services)
  tf.logging.info("Dialog metrics: %s", str(all_metric_aggregate[ALL_SERVICES]))

  # Write the aggregated metrics values.
  with tf.gfile.GFile(FLAGS.output_metric_file, "w") as f:
    json.dump(
        all_metric_aggregate,
        f,
        indent=2,
        separators=(",", ": "),
        sort_keys=True)
  # Write the per-frame metrics values with the corresponding dialogue frames.
  with tf.gfile.GFile(
      os.path.join(FLAGS.prediction_dir, PER_FRAME_OUTPUT_FILENAME), "w") as f:
    json.dump(dataset_hyp, f, indent=2, separators=(",", ": "))


if __name__ == "__main__":
  flags.mark_flag_as_required("prediction_dir")
  flags.mark_flag_as_required("dstc8_data_dir")
  flags.mark_flag_as_required("eval_set")
  flags.mark_flag_as_required("output_metric_file")
  tf.compat.v1.app.run(main)