# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Perplexity Metric."""

import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

import datasets
from datasets import logging


_CITATION = """\

"""

_DESCRIPTION = """
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence.

For more information, see https://huggingface.co/docs/transformers/perplexity
"""

_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity
            NOTE: Perplexity can only be calculated for causal language models.
                    This includes models such as gpt2, causal variations of bert,
                    causal versions of t5, and more (the full list can be found
                    in the AutoModelForCausalLM documentation here:
                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )

    input_texts (list of str): input text, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
    device (str): device to run on, defaults to 'cuda' when available
Returns:
    perplexity: dictionary containing the perplexity scores for the texts
        in the input list, as well as the mean perplexity. If one of the input texts is
        longer than the max input length of the model, then it is truncated to the
        max length for the perplexity computation.
Examples:
    Example 1:
        >>> perplexity = datasets.load_metric("perplexity")
        >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              add_start_token=False,
        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2))
        78.22
        >>> print(round(results["perplexities"][0], 2))
        11.11

    Example 2:
        >>> perplexity = datasets.load_metric("perplexity")
        >>> input_texts = datasets.load_dataset("wikitext",
        ...                                     "wikitext-2-raw-v1",
        ...                                     split="test")["text"][:50]
        >>> input_texts = [s for s in input_texts if s!='']
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2))
        60.35
        >>> print(round(results["perplexities"][0], 2))
        81.12
"""


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Perplexity(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "input_texts": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    def _compute(self, input_texts, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):
        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be one of 'gpu', 'cpu' or 'cuda'."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # if batch_size > 1 (which generally leads to padding being required), and
        # if there is not an already assigned pad_token, assign an existing
        # special token to also be the padding token
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
            # check that the model already has at least one special token defined
            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

        if add_start_token:
            # leave room for <BOS> token to be added:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
            max_tokenized_len = model.config.max_length - 1
        else:
            max_tokenized_len = model.config.max_length

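        # tokenize all texts at once: no special tokens are added here (a BOS token is
        # prepended per batch below when add_start_token=True), sequences are padded to
        # the longest text and truncated to max_tokenized_len, and the attention mask
        # records which positions are real tokens vs. padding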
        encodings = tokenizer(
            input_texts,
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=max_tokenized_len,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)

        encoded_texts = encodings["input_ids"]
        attn_masks = encodings["attention_mask"]

        # check that each input is long enough:
        if add_start_token:
            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
        else:
            assert torch.all(
                torch.ge(attn_masks.sum(1), 2)
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

        ppls = []
        loss_fct = CrossEntropyLoss(reduction="none")

        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
            end_index = min(start_index + batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attn_mask = attn_masks[start_index:end_index]

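            # optionally prepend the BOS token to every sequence in the batch so that the
            # first real token is also scored, and extend the attention mask to match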
            if add_start_token:
                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
                attn_mask = torch.cat(
                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
                )

            labels = encoded_batch

            with torch.no_grad():
                out_logits = model(encoded_batch, attention_mask=attn_mask).logits

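            # the labels are the (possibly BOS-prepended) input ids themselves; logits at
            # position i predict token i+1, so drop the last logit and the first label /
            # mask entry to align each prediction with its target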
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

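            # per-sequence score: token-level cross-entropy is masked so padding does not
            # contribute, averaged over the real target tokens, and then exponentiated
            # (torch.exp2, i.e. base 2) to give one perplexity value per text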
            perplexity_batch = torch.exp2(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                / shift_attention_mask_batch.sum(1)
            )

            ppls += perplexity_batch.tolist()

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}
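

if __name__ == "__main__":
    # Minimal, self-contained sanity-check sketch (not part of the metric itself):
    # on toy tensors it shows that exponentiating the average token cross-entropy
    # equals exp(-mean log-likelihood), the definition quoted at the top of this file.
    # All names here (toy_logits, toy_targets) are illustrative only.
    torch.manual_seed(0)
    toy_logits = torch.randn(1, 4, 10)           # (batch, seq_len, vocab_size)
    toy_targets = torch.randint(0, 10, (1, 4))   # next-token ids to score
    nll = CrossEntropyLoss(reduction="none")(toy_logits.transpose(1, 2), toy_targets)
    log_probs = torch.log_softmax(toy_logits, dim=-1).gather(-1, toy_targets.unsqueeze(-1)).squeeze(-1)
    print(torch.exp(nll.mean()).item())          # perplexity via cross-entropy
    print(torch.exp(-log_probs.mean()).item())   # same value, computed directly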
