# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Iterable, List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import functional as F


logger = logging.getLogger(__name__)


class GenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
    """

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def adjust_logits_during_generation(self, logits, **kwargs):
        return logits

    def _use_cache(self, outputs, use_cache):
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
        if len(outputs) <= 1 or use_cache is False:
            return False
        if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
            return False
        return True

    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """Apply the repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858)."""
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if the score is < 0, the repetition penalty has to be multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty

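    # Illustrative sketch (not part of the original code): the effect of the in-place
    # repetition penalty above on a couple of hand-picked scores, assuming `model` is
    # any instance of a class that mixes in GenerationMixin.
    #
    #   lprobs = torch.tensor([[2.0, -2.0, 0.5]])   # toy scores for a 3-token vocabulary
    #   prev = torch.tensor([[0, 1]])               # token ids 0 and 1 were generated before
    #   model.enforce_repetition_penalty_(lprobs, 1, 1, prev, repetition_penalty=1.2)
    #   # lprobs is now [[2.0 / 1.2, -2.0 * 1.2, 0.5]] = [[1.667, -2.4, 0.5]]:
    #   # repeated tokens are pushed down regardless of the sign of their score.
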
    def postprocess_next_token_scores(
        self,
        scores,
        input_ids,
        no_repeat_ngram_size,
        bad_words_ids,
        cur_len,
        min_length,
        max_length,
        eos_token_id,
        repetition_penalty,
        batch_size,
        num_beams,
    ):
        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(
                scores, batch_size, num_beams, input_ids, repetition_penalty,
            )

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            scores[:, eos_token_id] = -float("inf")

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            num_batch_hypotheses = batch_size * num_beams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_batch_tokens = calc_banned_ngram_tokens(
                input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
            )
            for i, banned_tokens in enumerate(banned_batch_tokens):
                scores[i, banned_tokens] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for i, banned_tokens in enumerate(banned_tokens):
                scores[i, banned_tokens] = -float("inf")

        return scores

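    # Illustrative sketch (not part of the original code): before `min_length` is reached,
    # the method above simply masks the EOS column so the model cannot end the sequence
    # early. For example, with eos_token_id=2, cur_len=3 and min_length=5 (`model` as in
    # the sketch above):
    #
    #   scores = torch.zeros(1, 5)
    #   scores = model.postprocess_next_token_scores(
    #       scores, input_ids=torch.tensor([[0, 1, 2]]), no_repeat_ngram_size=0,
    #       bad_words_ids=None, cur_len=3, min_length=5, max_length=20,
    #       eos_token_id=2, repetition_penalty=1.0, batch_size=1, num_beams=1,
    #   )
    #   # scores[:, 2] is now -inf; every other entry is unchanged.
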
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **model_specific_kwargs
    ) -> torch.LongTensor:
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape `(1,)`.

            max_length: (`optional`) int
                The max length of the sequence to be generated. Between `min_length` and infinity. Defaults to 20.

            min_length: (`optional`) int
                The min length of the sequence to be generated. Between 0 and infinity. Defaults to 0.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                If set to `True`, beam search is stopped when at least `num_beams` sentences are finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.

            top_p: (`optional`) float
                The cumulative probability of the highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Defaults to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.

            pad_token_id: (`optional`) int
                Padding token. Defaults to the model-specific `pad_token_id`, or `None` if it does not exist.

            bos_token_id: (`optional`) int
                BOS token. Defaults to `bos_token_id` as defined in the model's config.

            eos_token_id: (`optional`) int
                EOS token. Defaults to `eos_token_id` as defined in the model's config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Defaults to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.

            bad_words_ids: (`optional`) list of lists of int
                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Defaults to 1.

            attention_mask: (`optional`) `torch.LongTensor` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

                `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

            use_cache: (`optional`) bool
                If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.

            model_specific_kwargs: (`optional`) dict
                Additional model specific kwargs will be forwarded to the `forward` function of the model.

        Return:

            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from the initial context 'The dog'
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 sequences by sampling
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # do not allow duplicate outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        if hasattr(self.config, "vocab_size"):
            vocab_size = self.config.vocab_size
        elif (
            self.config.is_encoder_decoder
            and hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "vocab_size")
        ):
            vocab_size = self.config.decoder.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.shape[-1]
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, effective_batch_mult * num_beams, input_ids_len
            )

            input_ids = input_ids.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = attention_mask.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        if self.config.is_encoder_decoder:
            # create empty decoder_input_ids
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
            cur_len = 1

            assert (
                batch_size == encoder_outputs[0].shape[0]
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
            expanded_batch_idxs = (
                torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, num_beams * effective_batch_mult)
                .view(-1)
                .to(input_ids.device)
            )
            # expand encoder_outputs
            encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

        else:
            encoder_outputs = None
            cur_len = input_ids.shape[-1]

        assert (
            cur_len < max_length
        ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                model_specific_kwargs=model_specific_kwargs,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                model_specific_kwargs=model_specific_kwargs,
            )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = (encoder_outputs, None) if encoder_outputs is not None else None

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
            )

            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            scores = self.postprocess_next_token_scores(
                scores=next_token_logits,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=1,
            )

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logscores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if a sentence is unfinished and the token to add is eos, sent_lengths is filled with the current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
                # unfinished_sents is set to zero if eos is in the sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for newly generated input if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return input_ids

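    # Illustrative sketch (not part of the original code): the `unfinished_sents` /
    # `tokens_to_add` bookkeeping above, with batch_size=2, pad_token_id=0 and
    # eos_token_id=2. If sentence 0 already finished and the model proposes token 7
    # for both rows:
    #
    #   unfinished_sents = torch.tensor([0, 1])
    #   next_token = torch.tensor([7, 7])
    #   tokens_to_add = next_token * unfinished_sents + 0 * (1 - unfinished_sents)
    #   # -> tensor([0, 7]): the finished row keeps receiving pad tokens.
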
    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """ Generate sequences for each example with beam search.
        """

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = (encoder_outputs, None) if encoder_outputs is not None else None

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
            )
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]
            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            scores = self.postprocess_next_token_scores(
                scores=scores,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=num_beams,
            )

            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                scores.shape, (batch_size * num_beams, vocab_size)
            )

            if do_sample:
                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                # Temperature
                if temperature != 1.0:
                    _scores = _scores / temperature
                # Top-p/top-k filtering
                _scores = top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all beam_idxs
                _scores = _scores.contiguous().view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare tokens and match the output of greedy beam search)
                probs = F.softmax(_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
                # Compute next scores
                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

            else:
                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping the top hypotheses across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

            # next batch beam content
            next_batch_beam = []

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence, add a pad token
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token_id have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content, this will get added to next_batch_beam
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
                ):
                    # get beam and token IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence
                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted token since it is not eos_token
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # once the beam for the next step is full, don't add more tokens to it.
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we are done so that we can save a pad step if all(done)
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item(), cur_len
                )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch and update current length
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
            cur_len = cur_len + 1

            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for newly generated input if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                continue

            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
            ):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not, define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # shorter batches are padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

            # fill with hypotheses and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

        return decoded

    @staticmethod
    def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens

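# Illustrative sketch (not part of the original module): a tiny check of the n-gram
# banning above. With no_repeat_ngram_size=2 and the prefix [5, 6, 5], the bigram
# (5, 6) has already been generated, so token 6 is banned after the trailing 5.
def _example_calc_banned_ngram_tokens():
    prev_input_ids = torch.tensor([[5, 6, 5]])
    banned = calc_banned_ngram_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=3)
    assert banned == [[6]]
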

def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if the bad word is just one token, always ban it
            return True
        if len(tokens) > len(prev_tokens):
            # if the bad word tokens are longer than the previous tokens they can't be equal
            return False

        if prev_tokens[-len(tokens) :] == tokens:
            # if tokens match
            return True
        else:
            return False

    for prev_input_ids_slice in prev_input_ids:
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
                # if tokens do not match continue
                continue

            banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens

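# Illustrative sketch (not part of the original module): banning the two-token bad word
# [8, 9]. Token 9 is only banned when the previously generated ids already end with 8.
def _example_calc_banned_bad_words_ids():
    bad_words_ids = [[8, 9]]
    ends_with_8 = torch.tensor([[1, 8]])
    other = torch.tensor([[1, 2]])
    assert calc_banned_bad_words_ids(ends_with_8, bad_words_ids) == [[9]]
    assert calc_banned_bad_words_ids(other, bad_words_ids) == [[]]
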

def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

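# Illustrative sketch (not part of the original module): top-k filtering on a toy
# 4-token distribution. Only the two largest logits survive; the rest are set to -inf
# so that a subsequent softmax assigns them zero probability.
def _example_top_k_top_p_filtering():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
    assert filtered[0, 0] == 2.0 and filtered[0, 1] == 1.0
    assert torch.isinf(filtered[0, 2]) and torch.isinf(filtered[0, 3])
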

class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
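

# Illustrative sketch (not part of the original module): the length-penalised ranking
# kept by BeamHypotheses. With length_penalty=1.0, the score of a hypothesis is its
# summed log-probability divided by its length, and only the `num_beams` best are kept.
def _example_beam_hypotheses():
    hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
    hyps.add(torch.tensor([0, 1]), sum_logprobs=-2.0)         # score -1.0
    hyps.add(torch.tensor([0, 1, 2]), sum_logprobs=-9.0)      # score -3.0
    hyps.add(torch.tensor([0, 1, 2, 3]), sum_logprobs=-4.0)   # score -1.0, evicts the -3.0 hypothesis
    assert len(hyps) == 2 and hyps.worst_score == -1.0
    # once no open beam can beat the current worst kept score, the sentence is done
    assert hyps.is_done(best_sum_logprobs=-20.0, cur_len=4) is True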