# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation library for split+rephrase sentence decomposition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
from absl import logging
from nltk.translate import bleu_score as nltk_bleu_score
import numpy as np


def MacroAvgSentBLEU(ref_str_lists, hyp_strs):
  """Compute multi-reference BLEU (macro-averaged over sentences) using NLTK.

  Contents must already be split into tokens.

  Args:
    ref_str_lists: list of reference lists; each reference is a list of tokens.
    hyp_strs: list of hypotheses, each a list of tokens.

  Returns:
    (float) BLEU score
  """
  assert len(hyp_strs) == len(ref_str_lists)
  scores = []
  sentence_bleu_fn = MaybeEmulateMultiBleu(nltk_bleu_score.sentence_bleu)
  for references, hypothesis in zip(ref_str_lists, hyp_strs):
    scores.append(sentence_bleu_fn(references, hypothesis))
  return np.mean(scores)
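
# Illustrative only (hypothetical inputs, not part of the original module):
# with pre-tokenized data such as
#   refs = [[['the', 'cat', 'sat', '.']]]   # one item with one reference
#   hyps = [['the', 'cat', 'sat', '.']]     # one hypothesis
# MacroAvgSentBLEU(refs, hyps) averages NLTK sentence_bleu over items, so this
# perfectly matching single item scores 1.0.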


def GetTokenLists(ref_str_lists, hyp_strs, tokenize_fn=lambda s: s.split()):
  """Split already-tokenized strings into lists of tokens.

  Args:
    ref_str_lists: list(list(str)) of multi-reference items
    hyp_strs: list(str) hypotheses
    tokenize_fn: a function that splits a string into a list of tokens.

  Returns:
    references: tokenized references as a list(list(list(str))).
    hypotheses: tokenized hypotheses as a list(list(str)).
  """

  ref_str_lists_tokenized = [
      list(map(tokenize_fn, ref_list)) for ref_list in ref_str_lists
  ]
  hyp_strs_tokenized = list(map(tokenize_fn, hyp_strs))

  return ref_str_lists_tokenized, hyp_strs_tokenized
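
# Illustrative only: GetTokenLists([['a b .']], ['a b .']) returns
#   ([[['a', 'b', '.']]], [['a', 'b', '.']])
# i.e. whitespace tokenization applied to every reference and every hypothesis.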


def ReadParcels(line_iterator,
                parcel_sep='<::::>',
                reduce_to_single_analysis=False):
  r"""Parse one or more decompositions from each line in line_iterator.

  Each input item is split by tab and then by parcel_sep.

  Args:
    line_iterator: iterable over strings of the format
      First parcel . <::::> Second parcel.\tOther <::::> option .
    parcel_sep: the string symbol between two simple sentences (parcels).
    reduce_to_single_analysis: if True, assume each line has a single
      decomposition and reduce the return type accordingly (see below).

  Returns:
    parceled_instances: a list(list(list(str))). The example string above would
      yield the list item
        [["First parcel .", "Second parcel."], ["Other", "option ."]]
      When reduce_to_single_analysis=True, one dimension is stripped out such
      that the return value is a list(list(str)).
  """

  def SplitParcels(analysis):
    """Split one analysis string into list of non-empty parcels."""
    parcels = [parcel.strip() for parcel in analysis.split(parcel_sep)]
    return [p for p in parcels if p]

  # Parse input lines to multi-analysis parcel lists.
  parceled_instances = []
  for line in line_iterator:
    analyses = line.strip().split('\t')
    assert analyses
    parceled_instances.append([SplitParcels(analysis) for analysis in analyses])

  if reduce_to_single_analysis:
    assert all([len(analyses) == 1 for analyses in parceled_instances])
    parceled_instances = [analyses[0] for analyses in parceled_instances]

  return parceled_instances
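
# Illustrative only: for the docstring example,
#   ReadParcels(['First parcel . <::::> Second parcel.\tOther <::::> option .'])
# returns
#   [[['First parcel .', 'Second parcel.'], ['Other', 'option .']]]
# i.e. one outer entry per input line, each holding its reference
# decompositions as lists of parcels.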


def MaybeEmulateMultiBleu(nltk_target_fn):
  """Adds the emulate_multibleu argument to nltk_target_fn if necessary.

  The signatures of the NLTK functions corpus_bleu and sentence_bleu depend on
  the NLTK version. This function works around version differences encountered
  in the public and internal environments.

  Args:
    nltk_target_fn: a function that computes BLEU given arguments gold and
      predicted.

  Returns:
    a function that takes arguments gold and predicted, in the format
    expected by NLTK's corpus_bleu and sentence_bleu functions.
  """
  fn = nltk_target_fn

  return fn


def ComputeMetrics(pred, gold):
  """Calculates metrics and returns scores as a dict.

  Computes the following metrics:
    - corpus-level BLEU (multi-reference, the standard way)
    - macro-averaged sentence-level BLEU

  Args:
    pred: hypotheses as a list of strings
    gold: references as a list of lists of strings

  Returns:
    dict(string -> float) metrics
  """
  results = {}
  tok_gold, tok_pred = GetTokenLists(gold, pred)

  # Legacy tag.
  field = 'decomp'

  # Sentence-level BLEU.
  macro_avg_sent_bleu = MacroAvgSentBLEU(tok_gold, tok_pred) * 100.0
  results['bleu.macro_avg_sent.' + field] = macro_avg_sent_bleu

  # Corpus-level BLEU.
  corpus_bleu_fn = MaybeEmulateMultiBleu(nltk_bleu_score.corpus_bleu)
  corpus_bleu = corpus_bleu_fn(tok_gold, tok_pred) * 100.0
  results['bleu.corpus.' + field] = corpus_bleu
  logging.info('BLEU %s: %05.02f', field, corpus_bleu)

  return results
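
# Illustrative only: the returned dict contains two keys,
# 'bleu.macro_avg_sent.decomp' and 'bleu.corpus.decomp', both on a 0-100 scale.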


def NumTokens(s):
  """Returns the number of whitespace-separated tokens in s."""
  return len(s.split())


def LengthStatistics(data):
  """Computes simple length-based statistics over predicted decompositions.

  Reports the number of parcels per input sentence (the S/C metric in the
  paper) and the macro- and micro-averaged number of tokens per parcel
  (Tokens/S in the paper).

  Example of an item in data: ['parcel1 here .', 'parcel 2 here .']

  Args:
    data: list of parcel lists.

  Returns:
    dictionary of results
  """

  results = {}

  # Average number of parcels per decomposed instance.
  parcel_counts = [len(instance) for instance in data]
  results['lengths.simple_per_complex'] = np.mean(parcel_counts)

  # Token counts.
  token_counts = []
  for instance in data:
    token_counts.append([NumTokens(parcel) for parcel in instance])

  # Macro averaged number of tokens per parcel.
  results['lengths.tokens_per_simple'] = np.mean(
      [np.mean(counts) for counts in token_counts])

  # Micro averaged number of tokens per parcel.
  total_tokens = np.sum(list(itertools.chain.from_iterable(token_counts)))
  total_parcels = np.sum(parcel_counts)
  results['lengths.tokens_per_simple_micro'] = total_tokens / total_parcels

  return results
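
# Worked example (illustrative only): for data = [['a b .', 'c .'], ['d e f .']]
# the parcel counts are [2, 1] and the token counts [[3, 2], [4]], so
#   lengths.simple_per_complex      = mean([2, 1])     = 1.5
#   lengths.tokens_per_simple       = mean([2.5, 4.0]) = 3.25  (macro average)
#   lengths.tokens_per_simple_micro = (3 + 2 + 4) / 3  = 3.0   (micro average)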


def GoldLengthStatistics(data):
  """Computes simple length-based statistics over multi-reference data.

  Example of an item in data: [['parcel1 here .', 'parcel 2 here .'], [alt..]]

  Args:
    data: list of list of parcel lists.

  Returns:
    dictionary of results
  """

  results = {}

  # Macro-average number of parcels per decomposed instance.
  parcel_counts = []
  for instance in data:
    parcel_counts.append([len(analysis) for analysis in instance])

  results['ref_lengths.simple_per_complex'] = np.mean(
      [np.mean(counts) for counts in parcel_counts])

  # Token counts.
  token_counts = []
  for instance in data:
    instance_counts = []
    for analysis in instance:
      instance_counts.append([NumTokens(parcel) for parcel in analysis])
    token_counts.append(instance_counts)

  # Macro averaged number of tokens per parcel.
  token_means_per_analysis = []
  for instance in token_counts:
    token_means_per_analysis.append(
        [np.mean(analysis_counts) for analysis_counts in instance])

  results['ref_lengths.tokens_per_simple'] = np.mean(
      [np.mean(counts) for counts in token_means_per_analysis])

  return results


def PerformEval(gold, pred, debug=False):
  """Runs evaluation of predictions relative to references.

  Args:
    gold: gold references; each item is a list of one or more analyses, and each
      analysis is a list of parcel strings.
    pred: system predictions as a list of parcel lists.
    debug: if True, print a sample of the data.

  Returns:
    dictionary of results
  """

  logging.info('Gold labels: read %d rows', len(gold))
  logging.info('Predicted labels: read %d rows', len(pred))

  assert len(gold) == len(pred), (
      'Got unequal number of gold items ({}) and predictions ({})'.format(
          len(gold), len(pred)))

  if debug:
    print(gold[:2])
    print(pred[:2])

  results = {}

  # Calculate some stats on predictions.
  results.update(LengthStatistics(pred))
  results.update(GoldLengthStatistics(gold))

  # Collapse each analysis from a list of parcels into a single string,
  # since that is what we calculate metrics over.
  gold_decompositions = []
  for gold_instance in gold:
    gold_decompositions.append(
        [' '.join(parcel_list) for parcel_list in gold_instance])

  pred_decompositions = [' '.join(parcel_list) for parcel_list in pred]
  if debug:
    print(gold_decompositions[:2])
    print(pred_decompositions[:2])

  # Number of unique references per input.
  counts = [
      len(set(instance_references))
      for instance_references in gold_decompositions
  ]
  results['uniq_refs_per_input.avg'] = np.mean(counts)
  results['uniq_refs_per_input.min'] = np.min(counts)
  results['uniq_refs_per_input.max'] = np.max(counts)

  # Number of references per input.
  counts = [
      len(instance_references) for instance_references in gold_decompositions
  ]
  results['refs_per_input.avg'] = np.mean(counts)
  results['refs_per_input.min'] = np.min(counts)
  results['refs_per_input.max'] = np.max(counts)

  # Number of items in input data.
  results['counts.pred_inputs'] = len(pred)
  results['counts.gold_inputs'] = len(gold)

  # Number of individual items in input data (across analyses).
  results['counts.references'] = len(list(itertools.chain.from_iterable(gold)))
  results['counts.predictions'] = len(pred)

  # Calculate scoring metrics.
  results.update(
      ComputeMetrics(pred=pred_decompositions, gold=gold_decompositions))

  return results
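

# Minimal usage sketch (illustrative, not part of the original module): builds
# gold references and predictions from in-memory lines in the tab/'<::::>'-
# separated format that ReadParcels expects, then prints PerformEval results.
# The example lines below are hypothetical.
if __name__ == '__main__':
  _gold_lines = [
      'First parcel . <::::> Second parcel .\tOther <::::> option .'
  ]
  _pred_lines = ['First parcel . <::::> Second parcel .']
  _gold = ReadParcels(_gold_lines)
  _pred = ReadParcels(_pred_lines, reduce_to_single_analysis=True)
  for _name, _value in sorted(PerformEval(gold=_gold, pred=_pred).items()):
    print('%s: %s' % (_name, _value))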