# google-research (source listing: 422 lines, 15.0 KB)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Applies a Contrack model on some new data."""

import collections
import itertools
import json
import logging
import os
from typing import Dict, List, Text, Tuple

from absl import app
from absl import flags
import tensorflow as tf

from contrack import data
from contrack import encoding
from contrack import env
from contrack import model
from contrack import signals

# Command-line flags. The help strings are runtime text and are kept verbatim.
flags.DEFINE_string('model_path', '',
                    'Path to directory where the model is stored.')
flags.DEFINE_bool(
    'eval', True, 'If true, compare with target label containd in the input '
    'data and output accuracy metrics.')
flags.DEFINE_string(
    'input_data_glob', '',
    'A TF glob pattern specifying the location of the evaluation data files.')
flags.DEFINE_string(
    'clusters_file', '',
    'A jsonline file to which the predicted clusters are added.')
flags.DEFINE_bool(
    'teacher_forcing', True,
    'If true, use true instead of predicted labels for repository.')

FLAGS = flags.FLAGS

# Tokens that keep their own per-word accuracy bucket in _compute_stats; any
# other enref token that is not a known first name is counted as 'OTHER'.
PRONOUNS = [
    'i', 'me', 'my', 'you', 'your', 'he', 'him', 'his', 'she', 'her', 'we',
    'our', 'us', 'they', 'their', 'them', 'there', 'here', 'it'
]

# (domain, section) pairs for which tp/fp/fn and precision/recall/F1 are
# reported. 'all' aggregates over every domain.
METRICS = [(domain, section)
           for domain in ('people', 'locations', 'all')
           for section in ('new_entity', 'entities', 'properties',
                           'membership')]

# Added to denominators so that precision/recall/F1 never divide by zero.
EPSILON = 1e-10

65
def _get_named_slices(y_true, logits, section_name):
  """Returns the slices (given by name) of the true and predicted vectors.

  The predicted slice is multiplied by `y_true.enref_meta.is_enref()` (and,
  for 'membership', additionally by `y_true.enref_properties.is_group()`), so
  predictions are masked by what the true label says about the token.

  Args:
    y_true: Encoding of the target label; must expose `enref_meta`,
      `enref_id`, `enref_properties` and `enref_membership`.
    logits: Encoding of the prediction, with the same attributes.
    section_name: One of 'new_entity', 'entities', 'properties', 'membership'.

  Returns:
    A `(true_slice, masked_predicted_slice)` tuple.

  Raises:
    ValueError: If `section_name` is not one of the names listed above.
  """
  is_entity = y_true.enref_meta.is_enref()

  if section_name == 'new_entity':
    return (y_true.enref_meta.get_is_new_slice(),
            is_entity * logits.enref_meta.get_is_new_slice())

  if section_name == 'entities':
    return (y_true.enref_id.slice(), is_entity * logits.enref_id.slice())

  if section_name == 'properties':
    return (y_true.enref_properties.slice(),
            is_entity * logits.enref_properties.slice())

  if section_name == 'membership':
    is_group = y_true.enref_properties.is_group()
    return (y_true.enref_membership.slice(),
            is_entity * is_group * logits.enref_membership.slice())

  raise ValueError('Unknown section name %s' % section_name)

85
def _compute_stats(x, y_pred, environment):
  """Computes statistics about accuracy on enrefs in certain categories.

  Args:
    x: Mapping of input features. The keys read here are 'scenario_id' and
      'word_seq' (utf-8 encoded bytes), 'token_seq_length',
      'state_seq_length' and 'annotation_seq'.
    y_pred: Predictions, indexed as `y_pred[i, position, :]`; `shape[0]` is
      the number of examples.
    environment: Object providing `encodings.as_prediction_encoding()` and
      `config.max_seq_len`.

  Returns:
    A `(stats, other_entity_tokens)` tuple. `stats` maps
    '<domain>/<section>/{tp,fp,fn,pr,re,f1}' to numbers and
    '<word>', '<domain>/stats' and 'by_turn/<n>' to
    `[true enrefs, also predicted as enref, predicted with the true id]`.
    `other_entity_tokens` counts enref tokens that are neither a known first
    name nor a pronoun.
  """
  encodings = environment.encodings
  max_seq_len = environment.config.max_seq_len

  stats = {}
  for domain, section in METRICS:
    for counter in ('tp', 'fp', 'fn'):
      stats[f'{domain}/{section}/{counter}'] = 0
  stats['people/stats'] = [0, 0, 0]
  stats['locations/stats'] = [0, 0, 0]
  for turn in range(30):
    stats[f'by_turn/{turn}'] = [0, 0, 0]

  other_entity_tokens = collections.defaultdict(int)

  turn_nr = 0
  prev_scenario_id = ''
  for i in range(y_pred.shape[0]):
    scenario_id = x['scenario_id'][i].decode('utf-8')
    if scenario_id == prev_scenario_id:
      turn_nr += 1
    else:
      turn_nr = 0
      prev_scenario_id = scenario_id

    for j in range(x['token_seq_length'][i]):
      true_enc = encodings.as_prediction_encoding(x['annotation_seq'][i, j, :])
      pred_index = x['state_seq_length'][i] + j
      if pred_index >= max_seq_len:
        continue
      pred_enc = encodings.as_prediction_encoding(y_pred[i, pred_index, :])
      true_domain = true_enc.enref_properties.get_domain()

      if true_enc.enref_meta.is_enref() > 0.0:
        word = x['word_seq'][i, j, 0].decode('utf-8')

        if word in signals.FEMALE_NAMES:
          word = 'FFN'
        elif word in signals.MALE_NAMES:
          word = 'MFN'
        elif word not in PRONOUNS:
          other_entity_tokens[word] += 1
          word = 'OTHER'

        # setdefault() keeps the pre-created buckets and creates missing ones
        # on demand, so a scenario with 30 or more turns (or a domain other
        # than people/locations) no longer raises KeyError.
        buckets = (
            stats.setdefault(word, [0, 0, 0]),
            stats.setdefault(true_domain + '/stats', [0, 0, 0]),
            stats.setdefault(f'by_turn/{turn_nr}', [0, 0, 0]),
        )
        for bucket in buckets:
          bucket[0] += 1
        if pred_enc.enref_meta.is_enref() > 0.0:
          for bucket in buckets:
            bucket[1] += 1
          if pred_enc.enref_id.get() == true_enc.enref_id.get():
            for bucket in buckets:
              bucket[2] += 1

      # NOTE(review): indentation was lost in the listing this block was
      # recovered from. The metrics loop is placed at token level because
      # _get_named_slices already masks predictions with the true is_enref
      # value - confirm the nesting against the upstream file.
      for domain, section in METRICS:
        if domain != 'all' and domain != true_domain:
          continue
        true_y, logits = _get_named_slices(true_enc, pred_enc, section)
        pred_y = tf.cast(logits > 0.0, tf.float32)

        prefix = f'{domain}/{section}'
        stats[prefix + '/tp'] += tf.reduce_sum(true_y * pred_y).numpy()
        stats[prefix + '/fp'] += tf.reduce_sum(
            (1.0 - true_y) * pred_y).numpy()
        stats[prefix + '/fn'] += tf.reduce_sum(
            true_y * (1.0 - pred_y)).numpy()

  for domain, section in METRICS:
    prefix = f'{domain}/{section}'
    tp = stats[prefix + '/tp']
    precision = round(tp / (tp + stats[prefix + '/fp'] + EPSILON), 3)
    recall = round(tp / (tp + stats[prefix + '/fn'] + EPSILON), 3)
    stats[prefix + '/pr'] = precision
    stats[prefix + '/re'] = recall
    # F1 is derived from the already rounded precision and recall.
    stats[prefix + '/f1'] = round(
        2.0 * (precision * recall) / (precision + recall + EPSILON), 3)

  return stats, other_entity_tokens

170
def find_cluster(stats, m1, m2):
  """Checks if m1 and m2 share a true cluster and/or a predicted cluster.

  Args:
    stats: Dict with the entries 'true_clusters' and 'pred_clusters', each a
      mapping from a cluster key to a collection of mention ids.
    m1: First mention id.
    m2: Second mention id.

  Returns:
    A `(in_true_cluster, in_pred_cluster)` tuple of bools; each is True if at
    least one cluster of that kind contains both mentions.
  """
  in_true_cluster = any(
      m1 in cluster and m2 in cluster
      for cluster in stats['true_clusters'].values())
  in_pred_cluster = any(
      m1 in cluster and m2 in cluster
      for cluster in stats['pred_clusters'].values())
  return in_true_cluster, in_pred_cluster

185
def _safe_ratio(numerator, denominator):
  """Returns numerator / denominator, or 0.0 if the denominator is zero."""
  return numerator / denominator if denominator else 0.0


def _harmonic_f1(precision, recall):
  """Returns the F1 score of precision and recall (0.0 if both are zero)."""
  return _safe_ratio(2.0 * (precision * recall), precision + recall)


def _new_scene_stats(scene_id):
  """Returns empty per-scene mention/cluster bookkeeping for scene_id."""
  return {
      'id': scene_id,
      'm_id': 0,
      'true_clusters': collections.defaultdict(set),
      'pred_clusters': collections.defaultdict(set)
  }


def _add_scene_to_blanc(scene_stats, blanc_stats):
  """Adds every mention pair of a finished scene to the BLANC matrix."""
  for m1, m2 in itertools.combinations(range(scene_stats['m_id']), 2):
    in_true_cluster, in_pred_cluster = find_cluster(scene_stats, m1, m2)
    blanc_stats[1 - int(in_true_cluster)][1 - int(in_pred_cluster)] += 1


def _referenced_entities(enc):
  """Returns (is_enref, is_group, entity ids) for a prediction encoding."""
  if not enc.enref_meta.is_enref() > 0.0:
    return False, False, []
  if enc.enref_properties.is_group() > 0.0:
    return True, True, enc.enref_membership.get_ids()
  return True, False, [enc.enref_id.get()]


def _compute_entity_tracking_stats(x, y_pred, environment):
  """Computes entity linking and BLANC scores for the predictions.

  Args:
    x: Mapping of input features. The keys read here are 'scenario_id',
      'token_seq_length', 'state_seq_length' and 'annotation_seq'.
    y_pred: Predictions, indexed as `y_pred[i, position, :]`; `shape[0]` is
      the number of examples.
    environment: Object providing `encodings.as_prediction_encoding()` and
      `config.max_seq_len`.

  Returns:
    A `(el_results, blanc_results)` tuple. `el_results` holds
    '{singular,plural,both}_{precision,recall,f1}'; `blanc_results` holds
    'Pc', 'Rc', 'Pn', 'Rn', 'F1c', 'F1n', 'P', 'R' and 'F1'. A ratio whose
    denominator is zero is reported as 0.0 instead of raising.
  """
  encodings = environment.encodings
  max_seq_len = environment.config.max_seq_len

  el_stats = {}
  for category in ['singular', 'plural', 'both']:
    el_stats.update({
        f'{category}_true': 0,
        f'{category}_pred': 0,
        f'{category}_correct': 0,
    })

  # blanc_stats[a][b]: a == 0 means the pair shares a true cluster, b == 0
  # means it shares a predicted cluster.
  blanc_stats = [[0, 0], [0, 0]]
  scene_stats = None

  for i in range(y_pred.shape[0]):
    scene_id = x['scenario_id'][i]
    if scene_stats is None:
      scene_stats = _new_scene_stats(scene_id)
    elif scene_id != scene_stats['id']:
      _add_scene_to_blanc(scene_stats, blanc_stats)
      scene_stats = _new_scene_stats(scene_id)

    for j in range(x['token_seq_length'][i]):
      true_enc = encodings.as_prediction_encoding(x['annotation_seq'][i, j, :])
      pred_index = x['state_seq_length'][i] + j
      if pred_index >= max_seq_len:
        continue
      pred_enc = encodings.as_prediction_encoding(y_pred[i, pred_index, :])

      true_is_enref, true_is_group, true_entities = _referenced_entities(
          true_enc)
      _, pred_is_group, pred_entities = _referenced_entities(pred_enc)

      # Collect stats for Entity Linking F1 score. A non-enref token has no
      # entities and therefore adds zero to every counter.
      true_category = 'plural' if true_is_group else 'singular'
      pred_category = 'plural' if pred_is_group else 'singular'
      el_stats[f'{true_category}_true'] += len(true_entities)
      el_stats['both_true'] += len(true_entities)
      el_stats[f'{pred_category}_pred'] += len(pred_entities)
      el_stats['both_pred'] += len(pred_entities)

      for entity in true_entities:
        if entity in pred_entities:
          el_stats['both_correct'] += 1
          el_stats[f'{true_category}_correct'] += 1

      # Collect stats for BLANC. Mention ids index the true mentions of the
      # current scene, so predicted links are only recorded for tokens that
      # are true mentions. The id is read from the current scene's counter,
      # so the first mention after a scene change gets id 0.
      if true_is_enref:
        m_id = scene_stats['m_id']
        scene_stats['m_id'] += 1
        for e_id in true_entities:
          scene_stats['true_clusters'][e_id].add(m_id)
        for e_id in pred_entities:
          scene_stats['pred_clusters'][e_id].add(m_id)

  # The last scene is not followed by a scene change; add its pairs here.
  if scene_stats is not None:
    _add_scene_to_blanc(scene_stats, blanc_stats)

  el_results = {}
  for c in ['singular', 'plural', 'both']:
    precision = _safe_ratio(el_stats[f'{c}_correct'], el_stats[f'{c}_pred'])
    recall = _safe_ratio(el_stats[f'{c}_correct'], el_stats[f'{c}_true'])
    el_results[f'{c}_precision'] = precision
    el_results[f'{c}_recall'] = recall
    el_results[f'{c}_f1'] = _harmonic_f1(precision, recall)

  b = blanc_stats
  logging.info('B: %s', b)
  blanc_results = {
      'Pc': _safe_ratio(b[0][0], b[0][0] + b[1][0]),
      'Rc': _safe_ratio(b[0][0], b[0][0] + b[0][1]),
      'Pn': _safe_ratio(b[1][1], b[1][1] + b[0][1]),
      'Rn': _safe_ratio(b[1][1], b[1][1] + b[1][0]),
  }
  blanc_results['F1c'] = _harmonic_f1(blanc_results['Pc'],
                                      blanc_results['Rc'])
  blanc_results['F1n'] = _harmonic_f1(blanc_results['Pn'],
                                      blanc_results['Rn'])
  blanc_results.update({
      'P': (blanc_results['Pc'] + blanc_results['Pn']) / 2.0,
      'R': (blanc_results['Rc'] + blanc_results['Rn']) / 2.0,
      'F1': (blanc_results['F1c'] + blanc_results['F1n']) / 2.0
  })

  return el_results, blanc_results

308
def _close_enref(clusters, enref, last_token):
  """Appends the span of a finished enref to the cluster of its entity."""
  entity_id, first_token, word = enref
  spans = clusters.setdefault(entity_id, [])
  logging.info('Adding enref %s to cluster %s', word, spans)
  spans.append((first_token, last_token))


def _save_clusters(x, y_pred, environment, file_name):
  """Saves the predicted clusters to a jsonlines file.

  Reads the examples in `file_name`, adds a 'predicted_clusters' entry (a list
  of clusters, each a list of `(first_token, last_token)` spans) to the
  example with doc_key 'tc/<scenario id>' for every scenario in `x`, and
  writes all examples to a new file next to `file_name`.

  Args:
    x: Mapping of input features. The keys read here are 'scenario_id' and
      'word_seq' (utf-8 encoded bytes), 'token_seq_length' and
      'state_seq_length'.
    y_pred: Predictions, indexed as `y_pred[i, position, :]`; `shape[0]` is
      the number of examples.
    environment: Object providing `encodings.as_prediction_encoding()` and
      `config.max_seq_len`.
    file_name: Path of the input jsonlines file.
  """
  encodings = environment.encodings
  max_seq_len = environment.config.max_seq_len

  examples = {}
  with tf.io.gfile.GFile(file_name, 'r') as input_file:
    for line in input_file:
      example = json.loads(line)
      examples[example['doc_key']] = example

  prev_id = x['scenario_id'][0].decode('utf-8')
  clusters = {}
  num_tokens = 0
  # The enref currently being read: (entity id, first token index, token).
  open_enref = None
  for i in range(y_pred.shape[0]):
    s_id = x['scenario_id'][i].decode('utf-8')
    if s_id != prev_id:
      # An enref that reaches the end of the scenario is still open; close it
      # so that it is not dropped.
      if open_enref is not None:
        _close_enref(clusters, open_enref, num_tokens - 1)
      examples['tc/' + prev_id]['predicted_clusters'] = list(clusters.values())
      prev_id = s_id
      clusters = {}
      num_tokens = 0
      open_enref = None

    for j in range(x['token_seq_length'][i]):
      token = x['word_seq'][i, j, 0].decode('utf-8')
      if token.startswith('['):
        continue

      pred_index = x['state_seq_length'][i] + j
      if pred_index >= max_seq_len:
        continue
      pred_enc = encodings.as_prediction_encoding(y_pred[i, pred_index, :])

      # Only non-group enrefs are clustered; -1 means "no enref here".
      pred_id = -1
      if (pred_enc.enref_meta.is_enref() > 0 and
          pred_enc.enref_properties.is_group() <= 0):
        pred_id = pred_enc.enref_id.get()

      # An open enref always has an id >= 0, so this also covers pred_id == -1.
      if open_enref is not None and pred_id != open_enref[0]:
        _close_enref(clusters, open_enref, num_tokens - 1)
        open_enref = None
      # Start a span only when none is open. Restarting on every token of the
      # same enref would move the span start forward, past the span end for
      # '##' continuation tokens.
      if pred_id >= 0 and open_enref is None:
        logging.info('Starting new enref: %s (%d)', token, num_tokens)
        open_enref = (pred_id, num_tokens, token)

      # Tokens starting with '##' do not advance the token index.
      if not token.startswith('##'):
        logging.info('%s: %d', token, num_tokens)
        num_tokens += 1

  if open_enref is not None:
    _close_enref(clusters, open_enref, num_tokens - 1)
  examples['tc/' + prev_id]['predicted_clusters'] = list(clusters.values())

  teacher_forcing = 'withtf' if FLAGS.teacher_forcing else 'withouttf'
  output_file_name = (os.path.splitext(file_name)[0] + '_predicted' +
                      teacher_forcing + '.jsonlines')
  logging.info('Writing to %s', output_file_name)
  with tf.io.gfile.GFile(output_file_name, 'w') as output_file:
    for e in examples.values():
      output_file.write(json.dumps(e) + '\n')

373
def main(argv):
  """Loads a saved Contrack model and evaluates it or reports predictions.

  With --eval the model is only evaluated with Keras `evaluate()`. Otherwise
  predictions are computed, accuracy, entity linking and BLANC statistics are
  logged, and, if --clusters_file is set, predicted clusters are saved.

  Args:
    argv: Command-line arguments left over after flag parsing; unused.
  """
  del argv  # Unused.

  env.Env.init_from_saved_model(FLAGS.model_path)
  environment = env.Env.get()
  if not FLAGS.teacher_forcing:
    environment.config.batch_size = 1
  logging.info('Inference with config:\n%s', environment.config)

  logging.info('Reading data from %s', FLAGS.input_data_glob)
  input_data = data.read_eval_data(FLAGS.input_data_glob, environment.config,
                                   environment.encodings)

  with tf.keras.utils.custom_object_scope(model.get_custom_objects()):
    contrack_model = tf.keras.models.load_model(FLAGS.model_path)

  contrack_model.print_predictions = True
  if not FLAGS.teacher_forcing:
    contrack_model.compile(run_eagerly=True)
    contrack_model.disable_teacher_forcing()

  if FLAGS.eval:
    contrack_model.evaluate(
        input_data, batch_size=environment.config.batch_size)
    return

  x, y_pred = contrack_model.predict(
      input_data, batch_size=environment.config.batch_size, verbose=1)

  stats, other_entities = _compute_stats(x, y_pred, environment)
  logging.info('Accuracy Stats:')
  for k, v in stats.items():
    logging.info('%s: %s', k, v)
  logging.info('Other entities: ')
  # Most frequent first; tokens seen fewer than five times are not listed.
  for word, count in sorted(other_entities.items(), key=lambda w: -w[1]):
    if count < 5:
      break
    logging.info('%s: %d', word, count)

  el_stats, blanc_stats = _compute_entity_tracking_stats(x, y_pred,
                                                         environment)
  logging.info('entity linking results: %s', str(el_stats))
  logging.info('BLANC results: %s', str(blanc_stats))

  if FLAGS.clusters_file:
    _save_clusters(x, y_pred, environment, FLAGS.clusters_file)

420
# Script entry point: absl parses the flags and then calls main().
if __name__ == '__main__':
  app.run(main)