# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Iterable, List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import functional as F


logger = logging.getLogger(__name__)

class GenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
    """

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def adjust_logits_during_generation(self, logits, **kwargs):
        return logits

    def _use_cache(self, outputs, use_cache):
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
        if len(outputs) <= 1 or use_cache is False:
            return False
        if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
            return False
        return True

    def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
        """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
        for i in range(batch_size * num_beams):
            for previous_token in set(prev_output_tokens[i].tolist()):
                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if lprobs[i, previous_token] < 0:
                    lprobs[i, previous_token] *= repetition_penalty
                else:
                    lprobs[i, previous_token] /= repetition_penalty

    def postprocess_next_token_scores(
        self,
        scores,
        input_ids,
        no_repeat_ngram_size,
        bad_words_ids,
        cur_len,
        min_length,
        max_length,
        eos_token_id,
        repetition_penalty,
        batch_size,
        num_beams,
    ):
        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(
                scores, batch_size, num_beams, input_ids, repetition_penalty,
            )

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            scores[:, eos_token_id] = -float("inf")

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            num_batch_hypotheses = batch_size * num_beams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_batch_tokens = calc_banned_ngram_tokens(
                input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
            )
            for i, banned_tokens in enumerate(banned_batch_tokens):
                scores[i, banned_tokens] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for i, banned_tokens in enumerate(banned_tokens):
                scores[i, banned_tokens] = -float("inf")

        return scores

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        max_length: Optional[int] = None,
        min_length: Optional[int] = None,
        do_sample: Optional[bool] = None,
        early_stopping: Optional[bool] = None,
        num_beams: Optional[int] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: Optional[float] = None,
        bad_words_ids: Optional[Iterable[int]] = None,
        bos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        num_return_sequences: Optional[int] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_start_token_id: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **model_specific_kwargs
    ) -> torch.LongTensor:
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `torch.LongTensor` of shape `(1,)`.

            max_length: (`optional`) int
                The max length of the sequence to be generated.  Between `min_length` and infinity. Defaults to 20.

            min_length: (`optional`) int
                The min length of the sequence to be generated.  Between 0 and infinity. Defaults to 0.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.

            top_p: (`optional`) float
                The cumulative probability of the highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Defaults to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.

            pad_token_id: (`optional`) int
                Padding token. Defaults to the model-specific pad_token_id or None if it does not exist.

            bos_token_id: (`optional`) int
                BOS token. Defaults to `bos_token_id` as defined in the models config.

            eos_token_id: (`optional`) int
                EOS token. Defaults to `eos_token_id` as defined in the models config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Defaults to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
            bad_words_ids: (`optional`) list of lists of int
                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Defaults to 1.

            attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

                `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id: (`optional`) int
                If an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

            use_cache: (`optional`) bool
                If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.

            model_specific_kwargs: (`optional`) dict
                Additional model specific kwargs will be forwarded to the `forward` function of the model.

        Return:

            output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) from initial context 'The dog'
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True)  # generate 3 independent sequences by sampling
            for i in range(3): #  3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
            model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """

        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head."
                "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

        if input_ids is not None:
            batch_size = input_ids.shape[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
        ), "`no_repeat_ngram_size` should be a positive integer."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
            )
        else:
            assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

        # do not allow duplicate outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

        # set pad_token_id to eos_token_id if not set. Important that this is done after
        # attention_mask is created
        if pad_token_id is None and eos_token_id is not None:
            logger.warning(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        if hasattr(self.config, "vocab_size"):
            vocab_size = self.config.vocab_size
        elif (
            self.config.is_encoder_decoder
            and hasattr(self.config, "decoder")
            and hasattr(self.config.decoder, "vocab_size")
        ):
            vocab_size = self.config.decoder.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()

            encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = input_ids.shape[-1]
            input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
            attention_mask = attention_mask.unsqueeze(1).expand(
                batch_size, effective_batch_mult * num_beams, input_ids_len
            )

            input_ids = input_ids.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = attention_mask.contiguous().view(
                effective_batch_size * num_beams, input_ids_len
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        if self.config.is_encoder_decoder:
            # create empty decoder_input_ids
            input_ids = torch.full(
                (effective_batch_size * num_beams, 1),
                decoder_start_token_id,
                dtype=torch.long,
                device=next(self.parameters()).device,
            )
            cur_len = 1

            assert (
                batch_size == encoder_outputs[0].shape[0]
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
            expanded_batch_idxs = (
                torch.arange(batch_size)
                .view(-1, 1)
                .repeat(1, num_beams * effective_batch_mult)
                .view(-1)
                .to(input_ids.device)
            )
            # expand encoder_outputs
            encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

        else:
            encoder_outputs = None
            cur_len = input_ids.shape[-1]

        assert (
            cur_len < max_length
        ), f"The context has {cur_len} tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"

        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                model_specific_kwargs=model_specific_kwargs,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                batch_size=effective_batch_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=use_cache,
                model_specific_kwargs=model_specific_kwargs,
            )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = input_ids.new(batch_size).fill_(1)
        sent_lengths = input_ids.new(batch_size).fill_(max_length)

        past = (encoder_outputs, None) if encoder_outputs is not None else None

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
            )

            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            scores = self.postprocess_next_token_scores(
                scores=next_token_logits,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=1,
            )

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]

            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    scores = scores / temperature
                # Top-p/top-k filtering
                next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logscores, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            # add token and increase length by one
            input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
            cur_len = cur_len + 1

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
                sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents.mul_((~eos_in_sents).long())

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

            # extend attention_mask for newly generated input if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        return input_ids

    def _generate_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        early_stopping,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        pad_token_id,
        eos_token_id,
        batch_size,
        num_return_sequences,
        length_penalty,
        num_beams,
        vocab_size,
        encoder_outputs,
        attention_mask,
        use_cache,
        model_specific_kwargs,
    ):
        """ Generate sequences for each example with beam search.
        """

        # generated hypotheses
        generated_hyps = [
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
            for _ in range(batch_size)
        ]

        # scores for each sentence in the beam
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens multiple times
        if do_sample is False:
            beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

        # cache compute states
        past = (encoder_outputs, None) if encoder_outputs is not None else None

        # done sentences
        done = [False for _ in range(batch_size)]

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
            )
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

            # if model has past, then set the past variable to speed up decoding
            if self._use_cache(outputs, use_cache):
                past = outputs[1]
            if self.config.is_encoder_decoder and do_sample is False:
                # TODO (PVP) still a bit hacky here - there might be a better solution
                next_token_logits = self.adjust_logits_during_generation(
                    next_token_logits, cur_len=cur_len, max_length=max_length
                )

            scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)

            scores = self.postprocess_next_token_scores(
                scores=scores,
                input_ids=input_ids,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                cur_len=cur_len,
                min_length=min_length,
                max_length=max_length,
                eos_token_id=eos_token_id,
                repetition_penalty=repetition_penalty,
                batch_size=batch_size,
                num_beams=num_beams,
            )

            assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
                scores.shape, (batch_size * num_beams, vocab_size)
            )

            if do_sample:
                _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
                # Temperature
                if temperature != 1.0:
                    _scores = _scores / temperature
                # Top-p/top-k filtering
                _scores = top_k_top_p_filtering(
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                )  # (batch_size * num_beams, vocab_size)
                # re-organize to group the beam together to sample from all beam_idxs
                _scores = _scores.contiguous().view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
                probs = F.softmax(_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
                # Compute next scores
                next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
                # sort the sampled vector to make sure that the first num_beams samples are the best
                next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
                next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

            else:
                next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

                # re-organize to group the beam together (we are keeping top hypothesis across beams)
                next_scores = next_scores.view(
                    batch_size, num_beams * vocab_size
                )  # (batch_size, num_beams * vocab_size)

                next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

            assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

            # next batch beam content
            next_batch_beam = []
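            # Each entry appended below is a tuple (beam_token_score, token_id, effective_beam_id)
            # describing one candidate continuation. 2 * num_beams candidates were gathered above
            # so that, even if some of them are EOS (and go to the finished hypotheses instead),
            # enough non-EOS continuations remain to refill the beam.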

            # for each sentence
            for batch_idx in range(batch_size):

                # if we are done with this sentence, add a pad token
                if done[batch_idx]:
                    assert (
                        len(generated_hyps[batch_idx]) >= num_beams
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                    assert (
                        eos_token_id is not None and pad_token_id is not None
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                    continue

                # next sentence beam content, this will get added to next_batch_beam
                next_sent_beam = []

                # next tokens for this sentence
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
                ):
                    # get beam and token IDs
                    beam_id = beam_token_id // vocab_size
                    token_id = beam_token_id % vocab_size

                    effective_beam_id = batch_idx * num_beams + beam_id
                    # add to generated hypotheses if end of sentence
                    if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                        # if beam_token does not belong to top num_beams tokens, it should not be added
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                        if is_beam_token_worse_than_top_num_beams:
                            continue
                        generated_hyps[batch_idx].add(
                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                        )
                    else:
                        # add next predicted token since it is not eos_token
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                    # once the beam for next step is full, don't add more tokens to it.
                    if len(next_sent_beam) == num_beams:
                        break

                # Check if we are done so that we can save a pad step if all(done)
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                    next_scores[batch_idx].max().item(), cur_len
                )

                # update next beam content
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"

            # stop when we are done with each sentence
            if all(done):
                break

            # sanity check / prepare next batch
            assert len(next_batch_beam) == batch_size * num_beams
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
            beam_idx = input_ids.new([x[2] for x in next_batch_beam])

            # re-order batch and update current length
            input_ids = input_ids[beam_idx, :]
            input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
            cur_len = cur_len + 1

            # re-order internal states
            if past is not None:
                past = self._reorder_cache(past, beam_idx)

            # extend attention_mask for newly generated input if the model is decoder-only
            if self.config.is_encoder_decoder is False:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                continue

            # test that beam scores match previously calculated scores if not eos and batch_idx not done
            if eos_token_id is not None and all(
                (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
            ):
                assert torch.all(
                    next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
                ), "If batch_idx is not done, final next scores: {} have to equal accumulated beam_scores: {}".format(
                    next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
                )

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(num_beams):
                effective_beam_id = batch_idx * num_beams + beam_id
                final_score = beam_scores[effective_beam_id].item()
                final_tokens = input_ids[effective_beam_id]
                generated_hyps[batch_idx].add(final_tokens, final_score)

        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

        # select the best hypotheses
        sent_lengths = input_ids.new(output_batch_size)
        best = []

        # retrieve best hypotheses
        for i, hypotheses in enumerate(generated_hyps):
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
            for j in range(output_num_return_sequences_per_batch):
                effective_batch_idx = output_num_return_sequences_per_batch * i + j
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[effective_batch_idx] = len(best_hyp)
                best.append(best_hyp)

        # shorter batches are padded
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
            decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

            # fill with hypothesis and eos_token_id if necessary
            for i, hypo in enumerate(best):
                decoded[i, : sent_lengths[i]] = hypo
                if sent_lengths[i] < max_length:
                    decoded[i, sent_lengths[i]] = eos_token_id
        else:
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
            decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

        return decoded

    @staticmethod
    def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
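        # Reorder each layer's cached states along dimension 1 (the batch * beam axis in the
        # default cache layout) so the cache stays aligned with the beams selected this step;
        # models with a different cache layout override this method.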
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)


def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> List[List[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < no_repeat_ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    def _get_generated_ngrams(hypo_idx):
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
        start_idx = cur_len + 1 - no_repeat_ngram_size
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
        return generated_ngrams[hypo_idx].get(ngram_idx, [])

    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
    return banned_tokens


def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
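    # For every hypothesis in `prev_input_ids`, collect the final token of each bad-word sequence
    # whose prefix matches the end of that hypothesis; those tokens are banned at the next step.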
    banned_tokens = []

    def _tokens_match(prev_tokens, tokens):
        if len(tokens) == 0:
            # if the bad word is just one token, always ban it
            return True
        if len(tokens) > len(prev_input_ids):
            # if bad word tokens are longer than prev input_ids they can't be equal
            return False

        if prev_tokens[-len(tokens) :] == tokens:
            # if tokens match
            return True
        else:
            return False

    for prev_input_ids_slice in prev_input_ids:
        banned_tokens_slice = []

        for banned_token_seq in bad_words_ids:
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
                bad_words_ids
            )

            if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
                # if tokens do not match continue
                continue

            banned_tokens_slice.append(banned_token_seq[-1])

        banned_tokens.append(banned_tokens_slice)

    return banned_tokens

def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> Tensor:
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size, vocabulary size)
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            Make sure we keep at least min_tokens_to_keep per batch example in the output
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (the first token is always kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret