1
# Copyright 2020 The HuggingFace Datasets Authors.
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""BERTScore metric."""
17
from contextlib import contextmanager
20
from packaging import version
26
@contextmanager
def filter_logging_context():
    """Temporarily suppress the benign transformers weight-initialization warning.

    Installs a filter on the ``transformers.modeling_utils`` logger for the
    duration of the ``with`` block and removes it afterwards, so loading the
    BERT model does not spam "This IS expected if you are initializing ..."
    messages during scoring.
    """

    def filter_log(record):
        # Drop only the known, expected initialization message; keep all others.
        return "This IS expected if you are initializing" not in record.msg

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        # Always restore normal logging, even if the wrapped code raises.
        logger.removeFilter(filter_log)
39
# BibTeX entry for the BERTScore paper (ICLR 2020), surfaced via MetricInfo.
_CITATION = """\
@inproceedings{bert-score,
  title={BERTScore: Evaluating Text Generation with BERT},
  author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi},
  booktitle={International Conference on Learning Representations},
  year={2020},
  url={https://openreview.net/forum?id=SkeHuCVFDr}
}
"""
49
# Short human-readable summary of the metric, surfaced via MetricInfo.
_DESCRIPTION = """\
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference
sentences by cosine similarity.
It has been shown to correlate with human judgment on sentence-level and system-level evaluation.
Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language
generation tasks.

See the project's README at https://github.com/Tiiiger/bert_score#readme for more information.
"""
58
# Usage documentation for `compute`, surfaced via MetricInfo.inputs_description.
_KWARGS_DESCRIPTION = """
BERTScore Metrics with the hashcode from a source against one or more references.

Args:
    predictions (list of str): Prediction/candidate sentences.
    references (list of str or list of list of str): Reference sentences.
    lang (str): Language of the sentences; required (e.g. 'en').
    model_type (str): Bert specification, default using the suggested
        model for the target language; has to specify at least one of
        `model_type` or `lang`.
    num_layers (int): The layer of representation to use,
        default using the number of layers tuned on WMT16 correlation data.
    verbose (bool): Turn on intermediate status update.
    idf (bool or dict): Use idf weighting; can also be a precomputed idf_dict.
    device (str): On which the contextual embedding model will be allocated on.
        If this argument is None, the model lives on cuda:0 if cuda is available.
    nthreads (int): Number of threads.
    batch_size (int): Bert score processing batch size.
    rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
        `lang` needs to be specified when `rescale_with_baseline` is True.
    baseline_path (str): Customized baseline file.
    use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10.

Returns:
    precision: Precision.
    recall: Recall.
    f1: F1 score.
    hashcode: Hashcode of the library.

Examples:

    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> bertscore = datasets.load_metric("bertscore")
    >>> results = bertscore.compute(predictions=predictions, references=references, lang="en")
    >>> print([round(v, 2) for v in results["f1"]])
    [1.0, 1.0]
"""
99
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
100
class BERTScore(datasets.Metric):
102
def _info(self):
    """Declare the metric's metadata: description, citation, and input schema.

    Inputs are one prediction string and a sequence of reference strings per
    example (`add`/`add_batch` below normalize bare-string references).
    """
    return datasets.MetricInfo(
        description=_DESCRIPTION,
        citation=_CITATION,
        homepage="https://github.com/Tiiiger/bert_score",
        inputs_description=_KWARGS_DESCRIPTION,
        features=datasets.Features(
            {
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
            }
        ),
        codebase_urls=["https://github.com/Tiiiger/bert_score"],
        reference_urls=[
            "https://github.com/Tiiiger/bert_score",
            "https://arxiv.org/abs/1904.09675",
        ],
    )
133
def _compute(
    self,
    predictions,
    references,
    lang=None,
    model_type=None,
    num_layers=None,
    verbose=False,
    idf=False,
    device=None,
    batch_size=64,
    nthreads=4,
    all_layers=False,
    rescale_with_baseline=False,
    baseline_path=None,
    use_fast_tokenizer=False,
):
    """Score `predictions` against `references` with bert-score.

    Returns:
        dict with per-sentence lists under "precision", "recall", "f1" and the
        bert-score configuration hash under "hashcode".

    Raises:
        ImportWarning: if `use_fast_tokenizer` is requested but the installed
            `bert-score` is older than 0.3.10.
    """
    get_hash = bert_score.utils.get_hash
    scorer = bert_score.BERTScorer

    # `use_fast_tokenizer` only exists in bert-score >= 0.3.10: bind it into
    # both factories when available, otherwise reject the option explicitly.
    if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
        get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
        scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
    elif use_fast_tokenizer:
        raise ImportWarning(
            "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of `bert-score` doesn't match this condition.\n"
            'You can install it with `pip install "bert-score>=0.3.10"`.'
        )

    # Default the model from the language, and the layer from the model.
    if model_type is None:
        assert lang is not None, "either lang or model_type should be specified"
        model_type = bert_score.utils.lang2model[lang.lower()]

    if num_layers is None:
        num_layers = bert_score.utils.model2layers[model_type]

    hashcode = get_hash(
        model=model_type,
        num_layers=num_layers,
        idf=idf,
        rescale_with_baseline=rescale_with_baseline,
        use_custom_baseline=baseline_path is not None,
    )

    with filter_logging_context():
        # Re-use the cached scorer across calls; rebuild it only when the
        # configuration hash changes (model loading is expensive).
        if not hasattr(self, "cached_bertscorer") or self.cached_bertscorer.hash != hashcode:
            self.cached_bertscorer = scorer(
                model_type=model_type,
                num_layers=num_layers,
                batch_size=batch_size,
                nthreads=nthreads,
                all_layers=all_layers,
                idf=idf,
                device=device,
                lang=lang,
                rescale_with_baseline=rescale_with_baseline,
                baseline_path=baseline_path,
            )

    (P, R, F) = self.cached_bertscorer.score(
        cands=predictions,
        refs=references,
        verbose=verbose,
        batch_size=batch_size,
    )
    output_dict = {
        "precision": P.tolist(),
        "recall": R.tolist(),
        "f1": F.tolist(),
        "hashcode": hashcode,
    }
    return output_dict
193
def add_batch(self, predictions=None, references=None, **kwargs):
    """Add a batch of predictions and references for the metric's stack."""
    # References may arrive either as plain strings or as lists of strings;
    # normalize every bare string to a one-element list so downstream code
    # always sees list-of-lists.
    if references is not None:
        references = [ref if not isinstance(ref, str) else [ref] for ref in references]
    super().add_batch(predictions=predictions, references=references, **kwargs)
201
def add(self, prediction=None, reference=None, **kwargs):
    """Add one prediction and reference for the metric's stack."""
    # A bare-string reference is normalized to a one-element list of strings.
    reference = [reference] if isinstance(reference, str) else reference
    super().add(prediction=prediction, reference=reference, **kwargs)