google-research
288 lines · 11.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Evaluate predictions JSON file, w.r.t. ground truth file."""
17
18from __future__ import absolute_import19from __future__ import division20from __future__ import print_function21
22import collections23import json24import os25import numpy as np26import tensorflow.compat.v1 as tf27
28from schema_guided_dst import metrics29
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "prediction_dir", None,
    "Directory in which all JSON files combined are predictions of the"
    " evaluation set on a single model checkpoint. We evaluate these JSON files"
    " by DSTC8 metrics.")
flags.DEFINE_string(
    "dstc8_data_dir", None,
    "Directory for the downloaded DSTC8 data, which contains the dialogue files"
    " and schema files of all datasets (train, dev, test)")
flags.DEFINE_enum("eval_set", None, ["train", "dev", "test"],
                  "Dataset split for evaluation.")
flags.DEFINE_string(
    "output_metric_file", None,
    "Single JSON output file containing aggregated evaluation metrics results"
    " for all predictions files in FLAGS.prediction_dir.")
flags.DEFINE_boolean(
    "joint_acc_across_turn", False,
    "Whether to compute joint accuracy across turn instead of across service. "
    "Should be set to True when conducting multiwoz style evaluation.")
flags.DEFINE_boolean(
    "use_fuzzy_match", True,
    "Whether to use fuzzy string matching when comparing non-categorical slot "
    "values. Should be set to False when conducting multiwoz style evaluation.")

# Keys under which metrics are aggregated across different subsets of services
# (all services, services seen in training, services unseen in training).
ALL_SERVICES = "#ALL_SERVICES"
SEEN_SERVICES = "#SEEN_SERVICES"
UNSEEN_SERVICES = "#UNSEEN_SERVICES"

# Name of the file containing all predictions and their corresponding frame
# metrics.
PER_FRAME_OUTPUT_FILENAME = "dialogues_and_metrics.json"
65
def get_service_set(schema_path):
  """Get the set of all services present in a schema."""
  with tf.gfile.GFile(schema_path) as schema_file:
    schema = json.load(schema_file)
  # Each entry of the schema list describes one service.
  return {service["service_name"] for service in schema}
75
def get_in_domain_services(schema_path_1, schema_path_2):
  """Get the set of common services between two schemas."""
  services_1 = get_service_set(schema_path_1)
  services_2 = get_service_set(schema_path_2)
  return services_1.intersection(services_2)
80
def get_dataset_as_dict(file_path_patterns):
  """Read the DSTC8 json dialog data as dictionary with dialog ID as keys."""
  dataset_dict = {}
  # A list argument is taken as an explicit list of files; any other argument
  # is treated as a glob pattern.
  if isinstance(file_path_patterns, list):
    file_paths = file_path_patterns
  else:
    file_paths = sorted(tf.gfile.Glob(file_path_patterns))
  for file_path in file_paths:
    # Skip the per-frame metrics output file, which may live alongside the
    # prediction files.
    if PER_FRAME_OUTPUT_FILENAME in file_path:
      continue
    tf.logging.info("Loading file: %s", file_path)
    with tf.gfile.GFile(file_path) as json_file:
      data = json.load(json_file)
    if isinstance(data, list):
      # A list of dialogues: index each one by its dialogue id.
      for dialog in data:
        dataset_dict[dialog["dialogue_id"]] = dialog
    elif isinstance(data, dict):
      # Already a mapping from dialogue id to dialogue.
      dataset_dict.update(data)
  return dataset_dict
101
def get_metrics(dataset_ref, dataset_hyp, service_schemas, in_domain_services):
  """Calculate the DSTC8 metrics.

  Args:
    dataset_ref: The ground truth dataset represented as a dict mapping dialogue
      id to the corresponding dialogue.
    dataset_hyp: The predictions in the same format as `dataset_ref`.
    service_schemas: A dict mapping service name to the schema for the service.
    in_domain_services: The set of services which are present in the training
      set.

  Returns:
    A tuple of two elements:
      all_metric_aggregate: A dict mapping a metric collection name to a dict
        containing the values for various metrics. Each metric collection
        aggregates the metrics across a specific set of frames in the
        dialogues.
      per_frame_metric: A dict mapping a frame id ("<dialogue>-<turn>-
        <service>") to the dict of metric values computed on that frame.

  Raises:
    ValueError: If the services, speakers or utterances in `dataset_hyp` do
      not match those in `dataset_ref`, or a predicted frame is missing.
  """
  # Metrics can be aggregated in various ways, eg over all dialogues, only for
  # dialogues containing unseen services or for dialogues corresponding to a
  # single service. This aggregation is done through metric_collections, which
  # is a dict mapping a collection name to a dict, which maps a metric to a list
  # of values for that metric. Each value in this list is the value taken by
  # the metric on a frame.
  metric_collections = collections.defaultdict(
      lambda: collections.defaultdict(list))

  # Ensure the dialogs in dataset_hyp also occur in dataset_ref.
  assert set(dataset_hyp.keys()).issubset(set(dataset_ref.keys()))

  # Store metrics for every frame for debugging.
  per_frame_metric = {}
  for dial_id, dial_hyp in dataset_hyp.items():
    dial_ref = dataset_ref[dial_id]

    if set(dial_ref["services"]) != set(dial_hyp["services"]):
      raise ValueError(
          "Set of services present in ground truth and predictions don't match "
          "for dialogue with id {}".format(dial_id))
    # The joint accuracy metrics, which are multiplied (not averaged) across
    # services of a turn when FLAGS.joint_acc_across_turn is set.
    joint_metrics = [
        metrics.JOINT_GOAL_ACCURACY, metrics.JOINT_CAT_ACCURACY,
        metrics.JOINT_NONCAT_ACCURACY
    ]
    for turn_id, (turn_ref, turn_hyp) in enumerate(
        zip(dial_ref["turns"], dial_hyp["turns"])):
      # Per-turn accumulator for the multiwoz-style (across-turn) joint
      # accuracy; starts at 1.0 because values are combined by product.
      metric_collections_per_turn = collections.defaultdict(
          lambda: collections.defaultdict(lambda: 1.0))
      if turn_ref["speaker"] != turn_hyp["speaker"]:
        raise ValueError(
            "Speakers don't match in dialogue with id {}".format(dial_id))

      # Skip system turns because metrics are only computed for user turns.
      if turn_ref["speaker"] != "USER":
        continue

      if turn_ref["utterance"] != turn_hyp["utterance"]:
        tf.logging.info("Ref utt: %s", turn_ref["utterance"])
        tf.logging.info("Hyp utt: %s", turn_hyp["utterance"])
        raise ValueError(
            "Utterances don't match for dialogue with id {}".format(dial_id))

      hyp_frames_by_service = {
          frame["service"]: frame for frame in turn_hyp["frames"]
      }

      # Calculate metrics for each frame in each user turn.
      for frame_ref in turn_ref["frames"]:
        service_name = frame_ref["service"]
        if service_name not in hyp_frames_by_service:
          raise ValueError(
              "Frame for service {} not found in dialogue with id {}".format(
                  service_name, dial_id))
        service = service_schemas[service_name]
        frame_hyp = hyp_frames_by_service[service_name]

        active_intent_acc = metrics.get_active_intent_accuracy(
            frame_ref, frame_hyp)
        slot_tagging_f1_scores = metrics.get_slot_tagging_f1(
            frame_ref, frame_hyp, turn_ref["utterance"], service)
        requested_slots_f1_scores = metrics.get_requested_slots_f1(
            frame_ref, frame_hyp)
        goal_accuracy_dict = metrics.get_average_and_joint_goal_accuracy(
            frame_ref, frame_hyp, service, FLAGS.use_fuzzy_match)

        frame_metric = {
            metrics.ACTIVE_INTENT_ACCURACY:
                active_intent_acc,
            metrics.REQUESTED_SLOTS_F1:
                requested_slots_f1_scores.f1,
            metrics.REQUESTED_SLOTS_PRECISION:
                requested_slots_f1_scores.precision,
            metrics.REQUESTED_SLOTS_RECALL:
                requested_slots_f1_scores.recall
        }
        # Slot tagging metrics may be absent (helper can return None).
        if slot_tagging_f1_scores is not None:
          frame_metric[metrics.SLOT_TAGGING_F1] = slot_tagging_f1_scores.f1
          frame_metric[metrics.SLOT_TAGGING_PRECISION] = (
              slot_tagging_f1_scores.precision)
          frame_metric[
              metrics.SLOT_TAGGING_RECALL] = slot_tagging_f1_scores.recall
        frame_metric.update(goal_accuracy_dict)

        frame_id = "{:s}-{:03d}-{:s}".format(dial_id, turn_id,
                                             frame_hyp["service"])
        per_frame_metric[frame_id] = frame_metric
        # Add the frame-level metric result back to dialogues.
        frame_hyp["metrics"] = frame_metric

        # Get the domain name of the service.
        domain_name = frame_hyp["service"].split("_")[0]
        domain_keys = [ALL_SERVICES, frame_hyp["service"], domain_name]
        if frame_hyp["service"] in in_domain_services:
          domain_keys.append(SEEN_SERVICES)
        else:
          domain_keys.append(UNSEEN_SERVICES)
        for domain_key in domain_keys:
          for metric_key, metric_value in frame_metric.items():
            # NaN-valued metrics are excluded from aggregation.
            if metric_value != metrics.NAN_VAL:
              if FLAGS.joint_acc_across_turn and metric_key in joint_metrics:
                metric_collections_per_turn[domain_key][
                    metric_key] *= metric_value
              else:
                metric_collections[domain_key][metric_key].append(metric_value)
      if FLAGS.joint_acc_across_turn:
        # Conduct multiwoz style evaluation that computes joint goal accuracy
        # across all the slot values of all the domains for each turn.
        for domain_key in metric_collections_per_turn:
          for metric_key, metric_value in metric_collections_per_turn[
              domain_key].items():
            metric_collections[domain_key][metric_key].append(metric_value)
  all_metric_aggregate = {}
  for domain_key, domain_metric_vals in metric_collections.items():
    domain_metric_aggregate = {}
    for metric_key, value_list in domain_metric_vals.items():
      if value_list:
        # Metrics are macro-averaged across all frames.
        domain_metric_aggregate[metric_key] = float(np.mean(value_list))
      else:
        domain_metric_aggregate[metric_key] = metrics.NAN_VAL
    all_metric_aggregate[domain_key] = domain_metric_aggregate
  return all_metric_aggregate, per_frame_metric
242
def main(_):
  """Evaluate all prediction files in FLAGS.prediction_dir and write metrics."""
  tf.logging.set_verbosity(tf.logging.INFO)

  eval_schema_path = os.path.join(FLAGS.dstc8_data_dir, FLAGS.eval_set,
                                  "schema.json")
  train_schema_path = os.path.join(FLAGS.dstc8_data_dir, "train", "schema.json")
  # Services appearing in both the training and evaluation schemas count as
  # "seen" for the seen/unseen metric split.
  in_domain_services = get_in_domain_services(eval_schema_path,
                                              train_schema_path)
  with tf.io.gfile.GFile(eval_schema_path) as schema_file:
    eval_services = {
        service["service_name"]: service
        for service in json.load(schema_file)
    }

  dataset_ref = get_dataset_as_dict(
      os.path.join(FLAGS.dstc8_data_dir, FLAGS.eval_set, "dialogues_*.json"))
  dataset_hyp = get_dataset_as_dict(
      os.path.join(FLAGS.prediction_dir, "*.json"))
  tf.logging.info("len(dataset_hyp)=%d, len(dataset_ref)=%d", len(dataset_hyp),
                  len(dataset_ref))
  if not dataset_hyp or not dataset_ref:
    raise ValueError("Hypothesis and/or reference dataset are empty!")

  all_metric_aggregate, _ = get_metrics(dataset_ref, dataset_hyp, eval_services,
                                        in_domain_services)
  tf.logging.info("Dialog metrics: %s", str(all_metric_aggregate[ALL_SERVICES]))

  # Write the aggregated metrics values.
  with tf.gfile.GFile(FLAGS.output_metric_file, "w") as metric_file:
    json.dump(
        all_metric_aggregate,
        metric_file,
        indent=2,
        separators=(",", ": "),
        sort_keys=True)
  # Write the per-frame metrics values with the corrresponding dialogue frames.
  per_frame_path = os.path.join(FLAGS.prediction_dir,
                                PER_FRAME_OUTPUT_FILENAME)
  with tf.gfile.GFile(per_frame_path, "w") as frame_file:
    json.dump(dataset_hyp, frame_file, indent=2, separators=(",", ": "))
282
if __name__ == "__main__":
  # All flags are required for a meaningful evaluation run.
  flags.mark_flag_as_required("prediction_dir")
  flags.mark_flag_as_required("dstc8_data_dir")
  flags.mark_flag_as_required("eval_set")
  flags.mark_flag_as_required("output_metric_file")
  # `tf` is already `tensorflow.compat.v1`, so use `tf.app.run` directly for
  # consistency with the rest of the file (was `tf.compat.v1.app.run`).
  tf.app.run(main)