# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Perplexity Metric."""
15
16import numpy as np17import torch18from torch.nn import CrossEntropyLoss19from transformers import AutoModelForCausalLM, AutoTokenizer20
21import datasets22from datasets import logging23
_CITATION = """\
"""
_DESCRIPTION = """
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence.
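Concretely, for a tokenized sequence X = (x_0, x_1, ..., x_t),
PPL(X) = exp( -1/t * sum_{i=1..t} log p(x_i | x_{<i}) ),
where log p(x_i | x_{<i}) is the log-likelihood of the i-th token given the preceding tokens.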
For more information, see https://huggingface.co/docs/transformers/perplexity
"""
_KWARGS_DESCRIPTION = """
Args:
    model_id (str): model used for calculating Perplexity
        NOTE: Perplexity can only be calculated for causal language models.
        This includes models such as gpt2, causal variations of bert,
        causal versions of t5, and more (the full list can be found
        in the AutoModelForCausalLM documentation here:
        https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )

    input_texts (list of str): input text, each separate text snippet
        is one list entry.
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
    add_start_token (bool): whether to add the start token to the texts,
        so the perplexity can include the probability of the first word. Defaults to True.
    device (str): device to run on, defaults to 'cuda' when available
Returns:
    perplexity: dictionary containing the perplexity scores for the texts
        in the input list, as well as the mean perplexity. If one of the input texts is
        longer than the max input length of the model, then it is truncated to the
        max length for the perplexity computation.
Examples:
    Example 1:
        >>> perplexity = datasets.load_metric("perplexity")
        >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              add_start_token=False,
        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2))
        78.22
        >>> print(round(results["perplexities"][0], 2))
        11.11

    Example 2:
        >>> perplexity = datasets.load_metric("perplexity")
        >>> input_texts = datasets.load_dataset("wikitext",
        ...                                     "wikitext-2-raw-v1",
        ...                                     split="test")["text"][:50]
        >>> input_texts = [s for s in input_texts if s!='']
        >>> results = perplexity.compute(model_id='gpt2',
        ...                              input_texts=input_texts) # doctest:+ELLIPSIS
        >>> print(list(results.keys()))
        ['perplexities', 'mean_perplexity']
        >>> print(round(results["mean_perplexity"], 2))
        60.35
        >>> print(round(results["perplexities"][0], 2))
        81.12
"""
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Perplexity(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "input_texts": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )
    def _compute(self, input_texts, model_id, batch_size: int = 16, add_start_token: bool = True, device=None):

        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be either gpu, cpu or cuda."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # if batch_size > 1 (which generally leads to padding being required), and
        # if there is not an already assigned pad_token, assign an existing
        # special token to also be the padding token
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
            # check that the model already has at least one special token defined
            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
        if add_start_token:
            # leave room for <BOS> token to be added:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
            max_tokenized_len = model.config.max_length - 1
        else:
            max_tokenized_len = model.config.max_length
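        # tokenize all texts at once: no special tokens (a BOS token is prepended
        # manually below when add_start_token=True), pad shorter texts to the longest
        # one, and truncate anything longer than the allowed maximum length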
        encodings = tokenizer(
            input_texts,
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=max_tokenized_len,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)

        encoded_texts = encodings["input_ids"]
        attn_masks = encodings["attention_mask"]
        # check that each input is long enough:
        if add_start_token:
            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
        else:
            assert torch.all(
                torch.ge(attn_masks.sum(1), 2)
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
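        # reduction="none" keeps one loss value per token, so padding positions can be
        # masked out and the negative log-likelihood averaged per sequence below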
        ppls = []
        loss_fct = CrossEntropyLoss(reduction="none")
        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
            end_index = min(start_index + batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attn_mask = attn_masks[start_index:end_index]
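            # prepend the BOS token to every sequence in the batch so that the
            # probability of the first real token is included in the loss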
            if add_start_token:
                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
                attn_mask = torch.cat(
                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
                )
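            # for causal-LM perplexity the targets are simply the input ids;
            # the shift below aligns them with the next-token predictions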
            labels = encoded_batch
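            # evaluation only: run the forward pass without tracking gradients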
            with torch.no_grad():
                out_logits = model(encoded_batch, attention_mask=attn_mask).logits
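            # the logit at position i predicts token i + 1, so drop the last logit and
            # the first label, and shift the attention mask the same way so that
            # padding tokens are excluded from the loss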
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
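            # per-sequence perplexity: masked sum of token negative log-likelihoods
            # divided by the number of real tokens, then exponentiated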
            perplexity_batch = torch.exp(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                / shift_attention_mask_batch.sum(1)
            )
            ppls += perplexity_batch.tolist()

        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}