google-research

Форк
0
422 строки · 15.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Applies a Contrack model on some new data."""
17

18
import collections
19
import json
20
import logging
21
import os
22
from typing import Dict, List, Text, Tuple
23

24
from absl import app
25
from absl import flags
26
import tensorflow as tf
27

28
from contrack import data
29
from contrack import encoding
30
from contrack import env
31
from contrack import model
32
from contrack import signals
33

34
# Command-line flags controlling where the model lives, what data to run it
# on, and how inference is performed.
flags.DEFINE_string('model_path', '',
                    'Path to directory where the model is stored.')
# NOTE(review): 'containd' is a typo in user-facing help text; left unchanged
# here because help strings are runtime output.
flags.DEFINE_bool(
    'eval', True, 'If true, compare with target label containd in the input '
    'data and output accuracy metrics.')
flags.DEFINE_string(
    'input_data_glob', '',
    'A TF glob pattern specifying the location of the evaluation data files.')
flags.DEFINE_string(
    'clusters_file', '',
    'A jsonline file to which the predicted clusters are added.')
flags.DEFINE_bool(
    'teacher_forcing', True,
    'If true, use true instead of predicted labels for repository.')

# Handle to the parsed flag values; populated when app.run() parses argv.
FLAGS = flags.FLAGS
50

51
# Surface forms treated as pronoun mentions; any other true-enref token is
# bucketed as 'OTHER' when per-word accuracy statistics are collected.
PRONOUNS = [
    'i', 'me', 'my', 'you', 'your', 'he', 'him', 'his', 'she', 'her', 'we',
    'our', 'us', 'they', 'their', 'them', 'there', 'here', 'it'
]

# All (domain, prediction section) pairs for which tp/fp/fn counters and
# precision/recall/F1 values are reported. Built as a product so the domain
# and section lists stay in lockstep.
METRICS = [(domain, section)
           for domain in ('people', 'locations', 'all')
           for section in ('new_entity', 'entities', 'properties',
                           'membership')]

# Guards ratio denominators against division by zero.
EPSILON = 1e-10
64

65

66
def _get_named_slices(y_true, logits,
67
                      section_name):
68
  """Returns the slices (given by name) of true and predictied vector."""
69
  is_entity = y_true.enref_meta.is_enref()
70
  if section_name == 'new_entity':
71
    return (y_true.enref_meta.get_is_new_slice(),
72
            is_entity * logits.enref_meta.get_is_new_slice())
73
  elif section_name == 'entities':
74
    return (y_true.enref_id.slice(), is_entity * logits.enref_id.slice())
75
  elif section_name == 'properties':
76
    return (y_true.enref_properties.slice(),
77
            is_entity * logits.enref_properties.slice())
78
  elif section_name == 'membership':
79
    is_group = y_true.enref_properties.is_group()
80
    return (y_true.enref_membership.slice(),
81
            is_entity * is_group * logits.enref_membership.slice())
82
  else:
83
    raise ValueError('Unknown section name %s' % section_name)
84

85

86
def _compute_stats(x, y_pred,
                   environment):
  """Computes statistics about accuracy on enrefs in certain categories.

  Args:
    x: batch feature dict from the eval pipeline; reads the keys
      'scenario_id', 'token_seq_length', 'state_seq_length', 'annotation_seq'
      and 'word_seq'. (Exact tensor shapes are defined by contrack.data —
      assumed here, confirm against the pipeline.)
    y_pred: model logits indexed [batch, sequence position, encoding dim].
    environment: contrack.env.Env providing encodings and model config.

  Returns:
    A (stats, other_entity_tokens) tuple: stats maps metric names to raw
    counts and derived precision/recall/F1 values; other_entity_tokens counts
    true-enref words that are neither pronouns nor known first names.
  """
  encodings = environment.encodings
  stats = {}
  # Raw true-positive / false-positive / false-negative counters per metric.
  for m in METRICS:
    stats[f'{m[0]}/{m[1]}/tp'] = 0
    stats[f'{m[0]}/{m[1]}/fp'] = 0
    stats[f'{m[0]}/{m[1]}/fn'] = 0
  # Per-domain and per-turn triples:
  # [true enrefs, detected enrefs, correctly identified enrefs].
  stats['people/stats'] = [0, 0, 0]
  stats['locations/stats'] = [0, 0, 0]
  for i in range(0, 30):
    stats[f'by_turn/{i}'] = [0, 0, 0]

  # Counts of true-enref tokens collapsed into the 'OTHER' bucket below.
  other_entity_tokens = collections.defaultdict(int)

  turn_nr = 0
  prev_scenario_id = ''
  for i in range(0, y_pred.shape[0]):
    # Consecutive rows with the same scenario id belong to the same dialog;
    # track the turn number within that dialog.
    if x['scenario_id'][i].decode('utf-8') == prev_scenario_id:
      turn_nr += 1
    else:
      turn_nr = 0
      prev_scenario_id = x['scenario_id'][i].decode('utf-8')

    for j in range(0, x['token_seq_length'][i]):
      true_enc = encodings.as_prediction_encoding(x['annotation_seq'][i, j, :])
      # Predictions are offset by the length of the state (repository) prefix.
      pred_index = x['state_seq_length'][i] + j
      if pred_index >= environment.config.max_seq_len:
        # Token falls beyond the model's maximum sequence length; skip it.
        continue
      pred_enc = encodings.as_prediction_encoding(y_pred[i, pred_index, :])

      if true_enc.enref_meta.is_enref() > 0.0:
        word = x['word_seq'][i, j, 0].decode('utf-8')

        # Bucket the surface form: female/male first names get the FFN/MFN
        # buckets, pronouns keep their own rows, everything else is 'OTHER'.
        if word in signals.FEMALE_NAMES:
          word = 'FFN'
        elif word in signals.MALE_NAMES:
          word = 'MFN'
        elif word not in PRONOUNS:
          other_entity_tokens[word] += 1
          word = 'OTHER'

        if word not in stats:
          stats[word] = [0, 0, 0]

        # Index 0: total true enrefs for this surface form / domain / turn.
        stats[word][0] += 1
        stats[true_enc.enref_properties.get_domain() + '/stats'][0] += 1
        stats[f'by_turn/{turn_nr}'][0] += 1
        # Index 1: the model also detected an enref at this position.
        if pred_enc.enref_meta.is_enref() > 0.0:
          stats[word][1] += 1
          stats[true_enc.enref_properties.get_domain() + '/stats'][1] += 1
          stats[f'by_turn/{turn_nr}'][1] += 1
        # Index 2: the predicted enref id matches the true id.
        # NOTE(review): this is counted even when no enref was predicted at
        # this position (no is_enref gate) — confirm that is intended.
        if pred_enc.enref_id.get() == true_enc.enref_id.get():
          stats[word][2] += 1
          stats[true_enc.enref_properties.get_domain() + '/stats'][2] += 1
          stats[f'by_turn/{turn_nr}'][2] += 1

      # Accumulate tp/fp/fn per metric; domain-specific metrics only see
      # tokens whose true domain matches ('all' sees every token).
      for m in METRICS:
        if (m[0] != 'all' and
            m[0] != true_enc.enref_properties.get_domain()):
          continue
        true_y, logits = _get_named_slices(true_enc, pred_enc, m[1])
        # A logit above zero counts as a positive prediction.
        pred_y = tf.cast(logits > 0.0, tf.float32)

        stats[f'{m[0]}/{m[1]}/tp'] += tf.reduce_sum(true_y * pred_y).numpy()
        stats[f'{m[0]}/{m[1]}/fp'] += tf.reduce_sum(
            (1.0 - true_y) * pred_y).numpy()
        stats[f'{m[0]}/{m[1]}/fn'] += tf.reduce_sum(
            true_y * (1.0 - pred_y)).numpy()

  # Derive precision/recall/F1 per metric; EPSILON avoids division by zero.
  # Note F1 is computed from the already-rounded precision/recall values.
  for m in METRICS:
    stats[f'{m[0]}/{m[1]}/pr'] = round(
        stats[f'{m[0]}/{m[1]}/tp'] /
        (stats[f'{m[0]}/{m[1]}/tp'] + stats[f'{m[0]}/{m[1]}/fp'] + EPSILON), 3)
    stats[f'{m[0]}/{m[1]}/re'] = round(
        stats[f'{m[0]}/{m[1]}/tp'] /
        (stats[f'{m[0]}/{m[1]}/tp'] + stats[f'{m[0]}/{m[1]}/fn'] + EPSILON), 3)
    stats[f'{m[0]}/{m[1]}/f1'] = round(
        2.0 * (stats[f'{m[0]}/{m[1]}/pr'] * stats[f'{m[0]}/{m[1]}/re']) /
        (stats[f'{m[0]}/{m[1]}/pr'] + stats[f'{m[0]}/{m[1]}/re'] + EPSILON), 3)

  return stats, other_entity_tokens
169

170

171
def find_cluster(stats, m1, m2):
  """Checks whether mentions m1 and m2 share a true and/or predicted cluster.

  Args:
    stats: dict with 'true_clusters' and 'pred_clusters', each mapping an
      entity id to the set of mention indices in that cluster.
    m1: index of the first mention.
    m2: index of the second mention.

  Returns:
    A (in_true_cluster, in_pred_cluster) pair of booleans.
  """
  in_true_cluster = any(
      m1 in members and m2 in members
      for members in stats['true_clusters'].values())
  in_pred_cluster = any(
      m1 in members and m2 in members
      for members in stats['pred_clusters'].values())
  return in_true_cluster, in_pred_cluster
184

185

186
def _tally_blanc_pairs(scene_stats, blanc_stats):
  """Folds every mention pair of a finished scene into the BLANC counts.

  blanc_stats is a 2x2 confusion matrix indexed
  [not coreferent in truth][not coreferent in prediction]: [0][0] counts
  pairs coreferent in both clusterings, [1][1] pairs coreferent in neither.
  """
  for m1 in range(0, scene_stats['m_id']):
    for m2 in range(0, m1):
      in_true_cluster, in_pred_cluster = find_cluster(scene_stats, m1, m2)
      blanc_stats[1 - int(in_true_cluster)][1 - int(in_pred_cluster)] += 1


def _new_scene_stats(scene_id):
  """Returns a fresh per-scene accumulator for BLANC cluster collection."""
  return {
      'id': scene_id,
      'm_id': 0,
      'true_clusters': collections.defaultdict(set),
      'pred_clusters': collections.defaultdict(set)
  }


def _compute_entity_tracking_stats(x, y_pred,
                                   environment):
  """Computes entity-linking F1 and BLANC coreference metrics.

  Args:
    x: batch feature dict from the eval pipeline; reads 'scenario_id',
      'token_seq_length', 'state_seq_length' and 'annotation_seq'.
    y_pred: model logits indexed [batch, sequence position, encoding dim].
    environment: contrack.env.Env providing encodings and model config.

  Returns:
    A (el_results, blanc_results) tuple: el_results holds precision/recall/F1
    for singular, plural and combined entity linking; blanc_results holds the
    BLANC coreference/non-coreference precision, recall and F1 values.
  """
  encodings = environment.encodings
  el_stats = {}
  for category in ['singular', 'plural', 'both']:
    el_stats.update({
        f'{category}_true': 0,
        f'{category}_pred': 0,
        f'{category}_correct': 0,
    })

  scene_stats = _new_scene_stats('')
  blanc_stats = [[0, 0], [0, 0]]

  for i in range(0, y_pred.shape[0]):
    for j in range(0, x['token_seq_length'][i]):
      true_enc = encodings.as_prediction_encoding(x['annotation_seq'][i, j, :])
      # Predictions are offset by the length of the state prefix.
      pred_index = x['state_seq_length'][i] + j
      if pred_index >= environment.config.max_seq_len:
        continue
      pred_enc = encodings.as_prediction_encoding(y_pred[i, pred_index, :])

      # Collect stats for Entity Linking F1 score: group enrefs contribute
      # one link per member, singular enrefs contribute one link.
      true_entities = []
      if true_enc.enref_meta.is_enref() > 0.0:
        if true_enc.enref_properties.is_group() > 0.0:
          true_entities = true_enc.enref_membership.get_ids()
          el_stats['plural_true'] += len(true_entities)
          el_stats['both_true'] += len(true_entities)
        else:
          true_entities = [true_enc.enref_id.get()]
          el_stats['singular_true'] += 1
          el_stats['both_true'] += 1

      pred_entities = []
      if pred_enc.enref_meta.is_enref() > 0.0:
        if pred_enc.enref_properties.is_group() > 0.0:
          pred_entities = pred_enc.enref_membership.get_ids()
          el_stats['plural_pred'] += len(pred_entities)
          el_stats['both_pred'] += len(pred_entities)
        else:
          pred_entities = [pred_enc.enref_id.get()]
          el_stats['singular_pred'] += 1
          el_stats['both_pred'] += 1

      for entity in true_entities:
        if entity in pred_entities:
          el_stats['both_correct'] += 1
          if true_enc.enref_properties.is_group() > 0.0:
            el_stats['plural_correct'] += 1
          else:
            el_stats['singular_correct'] += 1

      # Collect stats for BLANC.
      scene_id = x['scenario_id'][i]
      if not scene_stats['id']:
        scene_stats['id'] = scene_id
      if scene_id != scene_stats['id']:
        # Scene boundary: score all mention pairs of the finished scene.
        _tally_blanc_pairs(scene_stats, blanc_stats)
        scene_stats = _new_scene_stats(scene_id)

      # Bug fix: read the mention index only after a potential scene reset so
      # the first mention of a new scene starts at 0 instead of reusing the
      # previous scene's counter.
      m_id = scene_stats['m_id']
      if true_enc.enref_meta.is_enref() > 0.0:
        scene_stats['m_id'] += 1
        if true_enc.enref_properties.is_group() > 0.0:
          for e_id in true_enc.enref_membership.get_ids():
            scene_stats['true_clusters'][e_id].add(m_id)
        else:
          scene_stats['true_clusters'][true_enc.enref_id.get()].add(m_id)

        if pred_enc.enref_meta.is_enref() > 0.0:
          if pred_enc.enref_properties.is_group() > 0.0:
            for e_id in pred_enc.enref_membership.get_ids():
              scene_stats['pred_clusters'][e_id].add(m_id)
          else:
            scene_stats['pred_clusters'][pred_enc.enref_id.get()].add(m_id)

  # Bug fix: the final scene is never followed by a scene change, so flush it
  # explicitly; otherwise its mention pairs are dropped from the BLANC counts
  # (the analogous flush exists in _save_clusters).
  _tally_blanc_pairs(scene_stats, blanc_stats)

  el_results = {}
  for c in ['singular', 'plural', 'both']:
    # EPSILON guards empty counts against ZeroDivisionError, consistent with
    # the precision/recall computation in _compute_stats.
    el_results.update({
        f'{c}_precision':
            el_stats[f'{c}_correct'] / (el_stats[f'{c}_pred'] + EPSILON),
        f'{c}_recall':
            el_stats[f'{c}_correct'] / (el_stats[f'{c}_true'] + EPSILON),
    })
    el_results[f'{c}_f1'] = (
        2.0 * (el_results[f'{c}_precision'] * el_results[f'{c}_recall']) /
        (el_results[f'{c}_precision'] + el_results[f'{c}_recall'] + EPSILON))

  b = blanc_stats
  logging.info('B: %s', b)
  blanc_results = {
      'Pc': b[0][0] / (b[0][0] + b[1][0] + EPSILON),
      'Rc': b[0][0] / (b[0][0] + b[0][1] + EPSILON),
      'Pn': b[1][1] / (b[1][1] + b[0][1] + EPSILON),
      'Rn': b[1][1] / (b[1][1] + b[1][0] + EPSILON),
  }
  blanc_results['F1c'] = (2.0 * (blanc_results['Pc'] * blanc_results['Rc']) /
                          (blanc_results['Pc'] + blanc_results['Rc'] + EPSILON))
  blanc_results['F1n'] = (2.0 * (blanc_results['Pn'] * blanc_results['Rn']) /
                          (blanc_results['Pn'] + blanc_results['Rn'] + EPSILON))
  blanc_results.update({
      'P': (blanc_results['Pc'] + blanc_results['Pn']) / 2.0,
      'R': (blanc_results['Rc'] + blanc_results['Rn']) / 2.0,
      'F1': (blanc_results['F1c'] + blanc_results['F1n']) / 2.0
  })

  return el_results, blanc_results
307

308

309
def _save_clusters(x, y_pred,
                   environment, file_name):
  """Saves clusters to jsonlines file.

  Reads the jsonlines examples from file_name, attaches the predicted
  coreference clusters to each example (keyed by 'tc/<scenario_id>'), and
  writes the result to '<stem>_predicted<withtf|withouttf>.jsonlines'.

  Args:
    x: batch feature dict; reads 'scenario_id', 'token_seq_length',
      'state_seq_length' and 'word_seq'.
    y_pred: model logits indexed [batch, sequence position, encoding dim].
    environment: contrack.env.Env providing encodings and model config.
    file_name: path of the input jsonlines file (one JSON example per line,
      each with a 'doc_key' field).
  """
  encodings = environment.encodings
  examples = {}
  # Load the existing examples, keyed by their doc_key.
  with tf.io.gfile.GFile(file_name, 'r') as input_file:
    for line in input_file:
      example = json.loads(line)
      examples[example['doc_key']] = example

  prev_id = x['scenario_id'][0].decode('utf-8')
  clusters = {}
  # Word-level position counter; wordpiece continuations ('##...') are not
  # counted as separate words.
  num_tokens = 0
  # Currently open enref span as (enref id, start word index, first token),
  # or None when no span is open.
  prev_enref = None
  for i in range(0, y_pred.shape[0]):
    s_id = x['scenario_id'][i].decode('utf-8')
    if s_id != prev_id:
      # Scenario boundary: store the finished scenario's clusters and reset
      # all per-scenario state.
      examples['tc/' + prev_id]['predicted_clusters'] = list(clusters.values())
      prev_id = s_id
      clusters = {}
      num_tokens = 0
      prev_enref = None

    for j in range(0, x['token_seq_length'][i]):
      token = x['word_seq'][i, j, 0].decode('utf-8')
      if token.startswith('['):
        # Skip special tokens such as [CLS]/[SEP].
        continue

      # Predictions are offset by the length of the state prefix.
      pred_index = x['state_seq_length'][i] + j
      if pred_index >= environment.config.max_seq_len:
        continue
      pred_enc = encodings.as_prediction_encoding(y_pred[i, pred_index, :])

      # Only singular (non-group) enref predictions take part in clustering;
      # pred_id == -1 means "no clusterable enref at this token".
      pred_id = -1
      if (pred_enc.enref_meta.is_enref() > 0 and
          pred_enc.enref_properties.is_group() <= 0):
        pred_id = pred_enc.enref_id.get()

      # Close the open enref span when the prediction ends or switches to a
      # different enref id; the span covers [start, previous word].
      if prev_enref and (pred_id == -1 or pred_id != prev_enref[0]):
        cl_id = prev_enref[0]
        if cl_id not in clusters:
          clusters[cl_id] = []
        logging.info('Adding enref %s to cluster %s', prev_enref[2],
                     clusters[cl_id])
        clusters[cl_id].append((prev_enref[1], num_tokens - 1))
        prev_enref = None
      # NOTE(review): when pred_id matches the open span's id, prev_enref is
      # overwritten with the current position, moving the span start forward —
      # confirm multi-word enref spans are meant to be re-anchored per token.
      if pred_id >= 0:
        logging.info('Starting new enref: %s (%d)', token, num_tokens)
        prev_enref = (pred_id, num_tokens, token)

      # Advance the word counter only on whole words, not '##' continuations.
      if not token.startswith('##'):
        logging.info('%s: %d', token, num_tokens)
        num_tokens += 1

  # Flush the final scenario, which is never followed by a boundary.
  examples['tc/' + prev_id]['predicted_clusters'] = list(clusters.values())

  teacher_forcing = 'withtf' if FLAGS.teacher_forcing else 'withouttf'
  output_file_name = (os.path.splitext(file_name)[0] + '_predicted' +
                      teacher_forcing + '.jsonlines')
  logging.info('Writing to %s', output_file_name)
  with tf.io.gfile.GFile(output_file_name, 'w') as output_file:
    for e in examples.values():
      output_file.write(json.dumps(e) + '\n')
372

373

374
def main(argv):
  """Loads a saved Contrack model and runs evaluation or inference."""
  del argv  # Unused.

  # Restore the environment (encodings, config) stored with the saved model;
  # must happen before the data pipeline and model are created.
  env.Env.init_from_saved_model(FLAGS.model_path)
  environment = env.Env.get()
  if not FLAGS.teacher_forcing:
    # Without teacher forcing, predictions feed back into the repository, so
    # examples must be processed one at a time.
    environment.config.batch_size = 1
  logging.info('Inference with config:\n%s', environment.config)

  logging.info('Reading data from %s', FLAGS.input_data_glob)
  input_data = data.read_eval_data(FLAGS.input_data_glob, environment.config,
                                   environment.encodings)

  # Custom layers/metrics must be registered so Keras can deserialize them.
  with tf.keras.utils.custom_object_scope(model.get_custom_objects()):
    contrack_model = tf.keras.models.load_model(FLAGS.model_path)

  contrack_model.print_predictions = True
  if not FLAGS.teacher_forcing:
    # Eager execution lets the model consume its own predictions stepwise.
    contrack_model.compile(run_eagerly=True)
    contrack_model.disable_teacher_forcing()

  if FLAGS.eval:
    # Keras-level evaluation with the metrics compiled into the model.
    contrack_model.evaluate(
        input_data, batch_size=environment.config.batch_size)
  else:
    # Run inference, then compute and log the detailed accuracy statistics.
    x, y_pred = contrack_model.predict(
        input_data, batch_size=environment.config.batch_size, verbose=1)

    stats, other_entities = _compute_stats(x, y_pred, environment)
    logging.info('Accuracy Stats:')
    for k, v in stats.items():
      logging.info('%s: %s', k, v)
    logging.info('Other entities: ')
    # Report 'OTHER'-bucketed enref tokens seen at least 5 times, sorted by
    # descending count.
    for word, count in sorted(other_entities.items(), key=lambda w: -w[1]):
      if count < 5:
        break
      logging.info('%s: %d', word, count)

    el_stats, blanc_stats = _compute_entity_tracking_stats(x, y_pred,
                                                           environment)
    logging.info('entity linking results: %s', str(el_stats))
    logging.info('BLANC results: %s', str(blanc_stats))

    if FLAGS.clusters_file:
      _save_clusters(x, y_pred, environment, FLAGS.clusters_file)
419

420

421
if __name__ == '__main__':
  # Parse absl flags from the command line and dispatch to main().
  app.run(main)
423

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.