datasets

Форк
0
/
bertscore.py 
207 строк · 7.9 Кб
1
# Copyright 2020 The HuggingFace Datasets Authors.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""BERTScore metric."""
15

16
import functools
17
from contextlib import contextmanager
18

19
import bert_score
20
from packaging import version
21

22
import datasets
23

24

25
@contextmanager
def filter_logging_context():
    """Suppress one specific log message while the context is active.

    A filter is installed on the ``transformers.modeling_utils`` logger that
    drops every record whose message contains
    "This IS expected if you are initializing". The filter is removed again on
    exit, including when the wrapped code raises.
    """

    def filter_log(record):
        # `record.msg` is not guaranteed to be a string (logging accepts any
        # object and stringifies it lazily), so normalise it before the
        # substring test instead of risking a TypeError inside the filter.
        return "This IS expected if you are initializing" not in str(record.msg)

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        logger.removeFilter(filter_log)
36

37

38
_CITATION = """\
39
@inproceedings{bert-score,
40
  title={BERTScore: Evaluating Text Generation with BERT},
41
  author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi},
42
  booktitle={International Conference on Learning Representations},
43
  year={2020},
44
  url={https://openreview.net/forum?id=SkeHuCVFDr}
45
}
46
"""
47

48
_DESCRIPTION = """\
49
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference
50
sentences by cosine similarity.
51
It has been shown to correlate with human judgment on sentence-level and system-level evaluation.
52
Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language
53
generation tasks.
54

55
See the project's README at https://github.com/Tiiiger/bert_score#readme for more information.
56
"""
57

58
_KWARGS_DESCRIPTION = """
59
BERTScore Metrics with the hashcode from a source against one or more references.
60

61
Args:
62
    predictions (list of str): Prediction/candidate sentences.
63
    references (list of str or list of list of str): Reference sentences.
64
    lang (str): Language of the sentences; required (e.g. 'en').
65
    model_type (str): Bert specification, default using the suggested
66
        model for the target language; has to specify at least one of
67
        `model_type` or `lang`.
68
    num_layers (int): The layer of representation to use,
69
        default using the number of layers tuned on WMT16 correlation data.
70
    verbose (bool): Turn on intermediate status update.
71
    idf (bool or dict): Use idf weighting; can also be a precomputed idf_dict.
72
    device (str): On which the contextual embedding model will be allocated on.
73
        If this argument is None, the model lives on cuda:0 if cuda is available.
74
    nthreads (int): Number of threads.
75
    batch_size (int): Bert score processing batch size,
76
        at least one of `model_type` or `lang`. `lang` needs to be
77
        specified when `rescale_with_baseline` is True.
78
    rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
79
    baseline_path (str): Customized baseline file.
80
    use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10.
81

82
Returns:
83
    precision: Precision.
84
    recall: Recall.
85
    f1: F1 score.
86
    hashcode: Hashcode of the library.
87

88
Examples:
89

90
    >>> predictions = ["hello there", "general kenobi"]
91
    >>> references = ["hello there", "general kenobi"]
92
    >>> bertscore = datasets.load_metric("bertscore")
93
    >>> results = bertscore.compute(predictions=predictions, references=references, lang="en")
94
    >>> print([round(v, 2) for v in results["f1"]])
95
    [1.0, 1.0]
96
"""
97

98

99
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class BERTScore(datasets.Metric):
    # Metric wrapper around `bert_score.BERTScorer`. The class is documented
    # with comments rather than a docstring because the decorator above is
    # handed `_DESCRIPTION` / `_KWARGS_DESCRIPTION` for the class docstring.
    #
    # Instance state created lazily by `_compute`:
    #   cached_bertscorer      -- the most recently built scorer object
    #   _cached_bertscorer_key -- the settings that scorer was built with

    def _info(self):
        """Return the `datasets.MetricInfo` describing this metric.

        Predictions are declared as single strings and references as a
        sequence of strings per example.
        """
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/Tiiiger/bert_score",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                }
            ),
            codebase_urls=["https://github.com/Tiiiger/bert_score"],
            reference_urls=[
                "https://github.com/Tiiiger/bert_score",
                "https://arxiv.org/abs/1904.09675",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        lang=None,
        model_type=None,
        num_layers=None,
        verbose=False,
        idf=False,
        device=None,
        batch_size=64,
        nthreads=4,
        all_layers=False,
        rescale_with_baseline=False,
        baseline_path=None,
        use_fast_tokenizer=False,
    ):
        """Score `predictions` against `references` with a `bert_score.BERTScorer`.

        The scorer is kept on the instance and only rebuilt when a setting
        that was used to construct it changes between calls.

        Args:
            predictions: Candidate sentences, passed to the scorer as `cands`.
            references: Reference sentences, passed to the scorer as `refs`.
            lang: Language code; used to look up a default `model_type` when
                none is given, and forwarded to the scorer.
            model_type: Model identifier; looked up from `lang` when None.
            num_layers: Layer to use; looked up from `model_type` when None.
            verbose: Forwarded to the scorer's `score` call.
            idf, device, nthreads, all_layers, rescale_with_baseline,
            baseline_path: Forwarded to the scorer constructor.
            batch_size: Forwarded to both the constructor and `score`.
            use_fast_tokenizer: Only forwarded when bert_score >= 0.3.10.

        Returns:
            dict with keys "precision", "recall" and "f1" (each the result of
            `.tolist()` on the corresponding scorer output) and "hashcode".

        Raises:
            ImportWarning: `use_fast_tokenizer` is set but the installed
                bert_score is older than 0.3.10.
            AssertionError: neither `lang` nor `model_type` was given.
        """
        get_hash = bert_score.utils.get_hash
        scorer = bert_score.BERTScorer

        # `use_fast_tokenizer` is only passed to bert_score from 0.3.10 on.
        if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
            get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
            scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
        elif use_fast_tokenizer:
            raise ImportWarning(
                "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of `bert-score` doesn't match this condition.\n"
                'You can install it with `pip install "bert-score>=0.3.10"`.'
            )

        if model_type is None:
            # Raised explicitly (instead of via `assert`) so the check is not
            # stripped under `python -O`; the exception type is kept for
            # backward compatibility with existing callers.
            if lang is None:
                raise AssertionError("either lang or model_type should be specified")
            model_type = bert_score.utils.lang2model[lang.lower()]

        if num_layers is None:
            num_layers = bert_score.utils.model2layers[model_type]

        hashcode = get_hash(
            model=model_type,
            num_layers=num_layers,
            idf=idf,
            rescale_with_baseline=rescale_with_baseline,
            use_custom_baseline=baseline_path is not None,
        )

        # The hash above is not computed from `lang`, `all_layers`, `device`
        # or the concrete `baseline_path` (only whether one was given), yet
        # all of them are baked into the scorer at construction time. Key the
        # cache on them as well so a changed setting is never silently ignored.
        scorer_key = (hashcode, lang, all_layers, device, baseline_path)
        cache_is_stale = (
            not hasattr(self, "cached_bertscorer")
            or getattr(self, "_cached_bertscorer_key", None) != scorer_key
        )
        if cache_is_stale:
            with filter_logging_context():
                self.cached_bertscorer = scorer(
                    model_type=model_type,
                    num_layers=num_layers,
                    batch_size=batch_size,
                    nthreads=nthreads,
                    all_layers=all_layers,
                    idf=idf,
                    device=device,
                    lang=lang,
                    rescale_with_baseline=rescale_with_baseline,
                    baseline_path=baseline_path,
                )
            # Recorded only after construction succeeded.
            self._cached_bertscorer_key = scorer_key

        (P, R, F) = self.cached_bertscorer.score(
            cands=predictions,
            refs=references,
            verbose=verbose,
            batch_size=batch_size,
        )
        output_dict = {
            "precision": P.tolist(),
            "recall": R.tolist(),
            "f1": F.tolist(),
            "hashcode": hashcode,
        }
        return output_dict

    def add_batch(self, predictions=None, references=None, **kwargs):
        """Add a batch of predictions and references for the metric's stack."""
        # References can be strings or lists of strings
        # Let's change strings to lists of strings with one element
        if references is not None:
            references = [[ref] if isinstance(ref, str) else ref for ref in references]
        super().add_batch(predictions=predictions, references=references, **kwargs)

    def add(self, prediction=None, reference=None, **kwargs):
        """Add one prediction and reference for the metric's stack."""
        # References can be strings or lists of strings
        # Let's change strings to lists of strings with one element
        if isinstance(reference, str):
            reference = [reference]
        super().add(prediction=prediction, reference=reference, **kwargs)
208

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.