# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Evaluation library for split+rephrase sentence decompostion."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
from absl import logging
from nltk.translate import bleu_score as nltk_bleu_score
import numpy as np


def MacroAvgSentBLEU(ref_str_lists, hyp_strs):
  """Compute multi-reference BLEU (macro-averaged over sentences) using NLTK.

  Contents must already be split into tokens.

  Args:
    ref_str_lists: list of reference lists
    hyp_strs: list of hypothesis strings

  Returns:
    (float) BLEU score
  """
  assert len(hyp_strs) == len(ref_str_lists)
  scores = []
  sentence_bleu_fn = MaybeEmulateMultiBleu(nltk_bleu_score.sentence_bleu)
  for references, hypothesis in zip(ref_str_lists, hyp_strs):
    scores.append(sentence_bleu_fn(references, hypothesis))
  return np.mean(scores)
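
# Illustrative call, not part of the original module (the token lists below
# are made-up examples): with a single reference list that exactly matches
# the hypothesis, e.g.
#   MacroAvgSentBLEU([[['a', 'b', 'c', 'd']]], [['a', 'b', 'c', 'd']]),
# each per-sentence BLEU from NLTK is 1.0, so the macro average is 1.0.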


def GetTokenLists(ref_str_lists, hyp_strs, tokenize_fn=lambda s: s.split()):
  """Split tokenized strings into lists of tokens.

  Args:
    ref_str_lists: list(list(str)) of multi-reference items
    hyp_strs: list(str) hypotheses
    tokenize_fn: a function that splits a string into a list of tokens.

  Returns:
    references: tokenized references as a list(list(list(str))).
    hypotheses: tokenized hypotheses as a list(list(str)).
  """

  ref_str_lists_tokenized = [
      list(map(tokenize_fn, ref_list)) for ref_list in ref_str_lists
  ]
  hyp_strs_tokenized = list(map(tokenize_fn, hyp_strs))

  return ref_str_lists_tokenized, hyp_strs_tokenized
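
# Illustrative shapes (hypothetical values, not from the original code):
#   GetTokenLists([['a b .']], ['a b'])
# returns ([[['a', 'b', '.']]], [['a', 'b']]) under the default whitespace
# tokenizer.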


def ReadParcels(line_iterator,
                parcel_sep='<::::>',
                reduce_to_single_analysis=False):
  r"""Parse one or more decompositions from each line in line_iterator.

  Each input item is split by tab and then by parcel_sep.

  Args:
    line_iterator: iterable over strings of the format
      First parcel . <::::> Second parcel.\tOther <::::> option .
    parcel_sep: the string symbol between two simple sentences (parcels).
    reduce_to_single_analysis: if True, assume each line has a single
      decomposition and reduce the return type accordingly (see below).

  Returns:
    parceled_instances: a list(list(list(str))). The example string above would
      yield the list item
        [["First parcel .", "Second parcel."], ["Other", "option ."]]
      When reduce_to_single_analysis=True, one dimension is stripped out such
      that the return value is a list(list(str)).
  """

  def SplitParcels(analysis):
    """Split one analysis string into list of non-empty parcels."""
    parcels = [parcel.strip() for parcel in analysis.split(parcel_sep)]
    return [p for p in parcels if p]

  # Parse input lines to multi-analysis parcel lists.
  parceled_instances = []
  for line in line_iterator:
    analyses = line.strip().split('\t')
    assert analyses
    parceled_instances.append([SplitParcels(analysis) for analysis in analyses])

  if reduce_to_single_analysis:
    assert all([len(analyses) == 1 for analyses in parceled_instances])
    parceled_instances = [analyses[0] for analyses in parceled_instances]

  return parceled_instances
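
# Illustrative call (made-up strings mirroring the docstring example):
#   ReadParcels(['A . <::::> B .\tAB .'])
# returns [[['A .', 'B .'], ['AB .']]], i.e. one input with two analyses.
# With a single analysis per line,
#   ReadParcels(['A . <::::> B .'], reduce_to_single_analysis=True)
# collapses the result to [['A .', 'B .']].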


def MaybeEmulateMultiBleu(nltk_target_fn):
  """Includes emulate_multibleu argument into nltk_target_fn if necessary.

  The signatures of the NLTK functions corpus_bleu and sentence_bleu depend on
  the NLTK version. This function works around version differences encountered
  in the public and internal environments.

  Args:
    nltk_target_fn: a function that computes BLEU given arguments gold and
      predicted.

  Returns:
    a function that takes arguments gold and predicted, in the format
    expected by NLTK's corpus_bleu and sentence_bleu functions.
  """
  # In the public release no signature adaptation is needed, so the target
  # function is returned unchanged.
  fn = nltk_target_fn

  return fn


def ComputeMetrics(pred, gold):
  """Calculates metrics and returns scores as a dict.

  Computes the following metrics:
  - corpus-level BLEU
    - multi-reference, the standard way.
  - macro-averaged
    - sentence-level BLEU

  Args:
    pred: hypotheses as a list of strings
    gold: references as list of list of strings

  Returns:
    dict(string -> float) metrics
  """
  results = {}
  tok_gold, tok_pred = GetTokenLists(gold, pred)

  # Legacy tag.
  field = 'decomp'

  # Sentence-level BLEU.
  macro_avg_sent_bleu = MacroAvgSentBLEU(tok_gold, tok_pred) * 100.0
  results['bleu.macro_avg_sent.' + field] = macro_avg_sent_bleu

  # Corpus-level BLEU.
  corpus_bleu_fn = MaybeEmulateMultiBleu(nltk_bleu_score.corpus_bleu)
  corpus_bleu = corpus_bleu_fn(tok_gold, tok_pred) * 100.0
  results['bleu.corpus.' + field] = corpus_bleu
  logging.info('BLEU %s: %05.02f', field, corpus_bleu)

  return results
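
# The returned dict carries the keys 'bleu.macro_avg_sent.decomp' and
# 'bleu.corpus.decomp'; both scores are scaled to the 0-100 range.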


def NumTokens(s):
  return len(s.split())


def LengthStatistics(data):
  """Updates results with simple length-based statistics.

  Statistics computed:
  - parcels per input sentence (the S/C metric in the paper)
  - macro-averaged number of tokens per parcel (Tokens/S in the paper)

  Example of an item in data: ['parcel1 here .', 'parcel 2 here .']

  Args:
    data: list of parcel lists.

  Returns:
    dictionary of results
  """

  results = {}

  # Average number of parcels per decomposed instance.
  parcel_counts = [len(instance) for instance in data]
  results['lengths.simple_per_complex'] = np.mean(parcel_counts)

  # Token counts.
  token_counts = []
  for instance in data:
    token_counts.append([NumTokens(parcel) for parcel in instance])

  # Macro averaged number of tokens per parcel.
  results['lengths.tokens_per_simple'] = np.mean(
      [np.mean(counts) for counts in token_counts])

  # Micro averaged number of tokens per parcel.
  total_tokens = np.sum(list(itertools.chain.from_iterable(token_counts)))
  total_parcels = np.sum(parcel_counts)
  results['lengths.tokens_per_simple_micro'] = total_tokens / total_parcels

  return results
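
# Worked example (hypothetical input, not from the original code): for
# data = [['a b .', 'c .'], ['d e f .']], lengths.simple_per_complex is
# mean(2, 1) = 1.5, lengths.tokens_per_simple is mean(mean(3, 2), mean(4)) =
# mean(2.5, 4.0) = 3.25, and lengths.tokens_per_simple_micro is
# (3 + 2 + 4) / 3 = 3.0.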


def GoldLengthStatistics(data):
  """Updates results with simple length-based statistics over multi-ref data.

  Example of an item in data: [['parcel1 here .', 'parcel 2 here .'], [alt..]]

  Args:
    data: list of list of parcel lists.

  Returns:
    dictionary of results
  """

  results = {}

  # Macro-average number of parcels per decomposed instance.
  parcel_counts = []
  for instance in data:
    parcel_counts.append([len(analysis) for analysis in instance])

  results['ref_lengths.simple_per_complex'] = np.mean(
      [np.mean(counts) for counts in parcel_counts])

  # Token counts.
  token_counts = []
  for instance in data:
    instance_counts = []
    for analysis in instance:
      instance_counts.append([NumTokens(parcel) for parcel in analysis])
    token_counts.append(instance_counts)

  # Macro averaged number of tokens per parcel.
  token_means_per_analysis = []
  for instance in token_counts:
    token_means_per_analysis.append(
        [np.mean(analysis_counts) for analysis_counts in instance])

  results['ref_lengths.tokens_per_simple'] = np.mean(
      [np.mean(counts) for counts in token_means_per_analysis])

  return results


def PerformEval(gold, pred, debug=False):
  """Runs evaluation of predictions relative to references.

  Args:
    gold: gold references; each item is a list of one or more analyses, and
      each analysis is a list of parcel strings.
    pred: system predictions as a list of parcel lists.
    debug: debug mode prints out sample of data.

  Returns:
    dictionary of results
  """

  logging.info('Gold labels: read %d rows', len(gold))
  logging.info('Predicted labels: read %d rows', len(pred))

  assert len(gold) == len(pred), (
      'Got unequal number of gold items ({}) and predictions ({})'.format(
          len(gold), len(pred)))

  if debug:
    print(gold[:2])
    print(pred[:2])

  results = {}

  # Calculate some stats on predictions.
  results.update(LengthStatistics(pred))
  results.update(GoldLengthStatistics(gold))

  # Collapse each analysis from a list of parcels into a single string,
  # since that is what we calculate metrics over.
  gold_decompositions = []
  for gold_instance in gold:
    gold_decompositions.append(
        [' '.join(parcel_list) for parcel_list in gold_instance])

  pred_decompositions = [' '.join(parcel_list) for parcel_list in pred]
  if debug:
    print(gold_decompositions[:2])
    print(pred_decompositions[:2])

  # Number of unique references per input.
  counts = [
      len(set(instance_references))
      for instance_references in gold_decompositions
  ]
  results['uniq_refs_per_input.avg'] = np.mean(counts)
  results['uniq_refs_per_input.min'] = np.min(counts)
  results['uniq_refs_per_input.max'] = np.max(counts)

  # Number of references per input.
  counts = [
      len(instance_references) for instance_references in gold_decompositions
  ]
  results['refs_per_input.avg'] = np.mean(counts)
  results['refs_per_input.min'] = np.min(counts)
  results['refs_per_input.max'] = np.max(counts)

  # Number of items in input data.
  results['counts.pred_inputs'] = len(pred)
  results['counts.gold_inputs'] = len(gold)

  # Number of individual items in input data (across analyses).
  results['counts.references'] = len(list(itertools.chain.from_iterable(gold)))
  results['counts.predictions'] = len(pred)

  # Calculate scoring metrics.
  results.update(
      ComputeMetrics(pred=pred_decompositions, gold=gold_decompositions))

  return results
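

# Minimal usage sketch, not part of the original module: the toy gold and
# predicted lines below are hypothetical examples using the default '<::::>'
# parcel separator; real callers would read these lines from files instead.
if __name__ == '__main__':
  _gold_lines = [
      'John came home . <::::> He slept .\tJohn came home and slept .'
  ]
  _pred_lines = ['John came home . <::::> He slept .']
  _gold = ReadParcels(_gold_lines)
  _pred = ReadParcels(_pred_lines, reduce_to_single_analysis=True)
  for _name, _value in sorted(PerformEval(_gold, _pred).items()):
    print(_name, _value)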