CSS-LM / generation_tf_utils.py
# coding=utf-8
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16

17
import logging
18

19
import numpy as np
20
import tensorflow as tf
21

22

23
logger = logging.getLogger(__name__)
24

25

class TFGenerationMixin:
    """
    A class containing all of the functions supporting generation, to be used as a mixin in TFPreTrainedModel.
    """

31
    def prepare_inputs_for_generation(self, inputs, **kwargs):
32
        return {"inputs": inputs}
33

34
    def _use_cache(self, outputs, use_cache):
35
        """During generation, decide whether to pass the `past` variable to the next forward pass."""
36
        if len(outputs) <= 1 or use_cache is False:
37
            return False
38
        if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
39
            return False
40
        return True
41

42
    def generate(
43
        self,
44
        input_ids=None,
45
        max_length=None,
46
        min_length=None,
47
        do_sample=None,
48
        early_stopping=None,
49
        num_beams=None,
50
        temperature=None,
51
        top_k=None,
52
        top_p=None,
53
        repetition_penalty=None,
54
        bad_words_ids=None,
55
        bos_token_id=None,
56
        pad_token_id=None,
57
        eos_token_id=None,
58
        length_penalty=None,
59
        no_repeat_ngram_size=None,
60
        num_return_sequences=None,
61
        attention_mask=None,
62
        decoder_start_token_id=None,
63
        use_cache=None,
64
    ):
        r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, penalized greedy decoding, sampling with top-k or nucleus sampling,
        and beam search.

        Adapted in part from `Facebook's XLM beam search code`_.

        .. _`Facebook's XLM beam search code`:
           https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


        Parameters:

            input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
                The sequence used as a prompt for the generation. If `None` the method initializes
                it as an empty `tf.Tensor` of shape `(1,)`.

            max_length: (`optional`) int
                The maximum length of the sequence to be generated. Between 1 and infinity. Defaults to 20.

            min_length: (`optional`) int
                The minimum length of the sequence to be generated. Between 0 and infinity. Defaults to 0.

            do_sample: (`optional`) bool
                If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            early_stopping: (`optional`) bool
                If set to `True`, beam search is stopped when at least `num_beams` sentences are finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

            num_beams: (`optional`) int
                Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.

            temperature: (`optional`) float
                The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.

            top_k: (`optional`) int
                The number of highest probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Defaults to 50.

            top_p: (`optional`) float
                The cumulative probability threshold for nucleus (top-p) sampling: only the smallest set of most probable vocabulary tokens with cumulative probability >= `top_p` is kept. Must be between 0 and 1. Defaults to 1.

            repetition_penalty: (`optional`) float
                The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.

            bos_token_id: (`optional`) int
                Beginning of sentence token used if no prompt is provided. Defaults to the model-specific bos_token_id, or None if it does not exist.

            pad_token_id: (`optional`) int
                Pad token. Defaults to pad_token_id as defined in the model's config.

            eos_token_id: (`optional`) int
                EOS token. Defaults to eos_token_id as defined in the model's config.

            length_penalty: (`optional`) float
                Exponential penalty to the length. Defaults to 1.

            no_repeat_ngram_size: (`optional`) int
                If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.

            bad_words_ids: (`optional`) list of lists of int
                `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

            num_return_sequences: (`optional`) int
                The number of independently computed returned sequences for each element in the batch. Defaults to 1.

            attention_mask (`optional`) obj: `tf.Tensor` with `dtype=tf.int32` of same shape as `input_ids`
                Mask to avoid performing attention on padding token indices.
                Mask values selected in ``[0, 1]``:
                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
                Defaults to `None`.

                `What are attention masks? <../glossary.html#attention-mask>`__

            decoder_start_token_id: (`optional`) int
                Start token to use if an encoder-decoder model starts decoding with a different token than BOS.
                Defaults to `None` and is changed to `BOS` later.

            use_cache: (`optional`) bool
                If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.

        Return:

            output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
                sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`.

        Examples::

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            outputs = model.generate(max_length=40)  # do greedy decoding
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from the initial context 'The dog'
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
            input_context = 'The dog'
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # generate 3 sequences by sampling
            for i in range(3):  # 3 output sequences were generated
                print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
            input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
            print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

            tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
            model = TFAutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
            input_context = 'My cute dog'
            bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
            input_ids = tokenizer.encode(input_context, return_tensors='tf')  # encode input context
            outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
        """
184

185
        # We cannot generate if the model does not have a LM head
186
        if self.get_output_embeddings() is None:
187
            raise AttributeError(
188
                "You tried to generate sequences with a model that does not have a LM Head."
189
                "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
190
            )
191

192
        max_length = max_length if max_length is not None else self.config.max_length
193
        min_length = min_length if min_length is not None else self.config.min_length
194
        do_sample = do_sample if do_sample is not None else self.config.do_sample
195
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
196
        use_cache = use_cache if use_cache is not None else self.config.use_cache
197
        num_beams = num_beams if num_beams is not None else self.config.num_beams
198
        temperature = temperature if temperature is not None else self.config.temperature
199
        top_k = top_k if top_k is not None else self.config.top_k
200
        top_p = top_p if top_p is not None else self.config.top_p
201
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
202
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
203
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
204
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
205
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
206
        no_repeat_ngram_size = (
207
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
208
        )
209
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
210
        num_return_sequences = (
211
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
212
        )
213
        decoder_start_token_id = (
214
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
215
        )
216

217
        if input_ids is not None:
            batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
219
        else:
220
            batch_size = 1
221

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a non-negative integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a non-negative integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a non-negative integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a non-negative integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a non-negative integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
248

249
        if input_ids is None:
250
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
251
                "you should either supply a context to complete as `input_ids` input "
252
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
253
            )
254
            input_ids = tf.fill((batch_size, 1), bos_token_id)
255
        else:
256
            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
257

        # do not allow duplicated outputs when greedy decoding
259
        if do_sample is False:
260
            if num_beams == 1:
261
                # no_beam_search greedy generation conditions
262
                assert (
263
                    num_return_sequences == 1
264
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
265

266
            else:
267
                # beam_search greedy generation conditions
268
                assert (
269
                    num_beams >= num_return_sequences
270
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
271

272
        # create attention mask if necessary
273
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
274
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
275
            attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
276
        elif attention_mask is None:
277
            attention_mask = tf.ones_like(input_ids)
278

279
        if pad_token_id is None and eos_token_id is not None:
280
            logger.warning(
281
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
282
            )
283
            pad_token_id = eos_token_id
284

285
        # current position and vocab size
286
        cur_len = shape_list(input_ids)[1]
287
        vocab_size = self.config.vocab_size
288

289
        # set effective batch size and effective batch multiplier according to do_sample
290
        if do_sample:
291
            effective_batch_size = batch_size * num_return_sequences
292
            effective_batch_mult = num_return_sequences
293
        else:
294
            effective_batch_size = batch_size
295
            effective_batch_mult = 1
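        # Note: when sampling, each input row is duplicated num_return_sequences times up front and
        # decoded independently; for (greedy) beam search the batch is only expanded by num_beams and
        # the num_return_sequences outputs are picked from the final beams.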
296

297
        if self.config.is_encoder_decoder:
298
            if decoder_start_token_id is None:
299
                decoder_start_token_id = bos_token_id
300

301
            assert (
302
                decoder_start_token_id is not None
303
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
304
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
305
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
306

307
            # get encoder and store encoder outputs
308
            encoder = self.get_encoder()
309

310
            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
311

312
        # Expand input ids if num_beams > 1 or num_return_sequences > 1
313
        if num_return_sequences > 1 or num_beams > 1:
314
            input_ids_len = shape_list(input_ids)[-1]
315
            input_ids = tf.broadcast_to(
316
                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
317
            )
318
            attention_mask = tf.broadcast_to(
319
                tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
320
            )
321
            input_ids = tf.reshape(
322
                input_ids, (effective_batch_size * num_beams, input_ids_len)
323
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
324
            attention_mask = tf.reshape(
325
                attention_mask, (effective_batch_size * num_beams, input_ids_len)
326
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
327

328
        if self.config.is_encoder_decoder:
329

330
            # create empty decoder_input_ids
331
            input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
332
            cur_len = 1
333

334
            assert (
335
                batch_size == encoder_outputs[0].shape[0]
336
            ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
337

338
            # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
339
            expanded_batch_idxs = tf.reshape(
340
                tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
341
                shape=(-1,),
342
            )
343
            # expand encoder_outputs
344
            encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0), *encoder_outputs[1:])
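            # e.g. with batch_size=2, num_beams=3 and effective_batch_mult=1, expanded_batch_idxs is
            # [0, 0, 0, 1, 1, 1]: every beam of a sample is paired with that sample's encoder output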
345

346
        else:
347
            encoder_outputs = None
348
            cur_len = shape_list(input_ids)[-1]
349

350
        assert (
351
            cur_len < max_length
352
        ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
353

354
        if num_beams > 1:
355
            output = self._generate_beam_search(
356
                input_ids,
357
                cur_len=cur_len,
358
                max_length=max_length,
359
                min_length=min_length,
360
                do_sample=do_sample,
361
                early_stopping=early_stopping,
362
                temperature=temperature,
363
                top_k=top_k,
364
                top_p=top_p,
365
                repetition_penalty=repetition_penalty,
366
                no_repeat_ngram_size=no_repeat_ngram_size,
367
                bad_words_ids=bad_words_ids,
368
                bos_token_id=bos_token_id,
369
                pad_token_id=pad_token_id,
370
                eos_token_id=eos_token_id,
371
                decoder_start_token_id=decoder_start_token_id,
372
                batch_size=effective_batch_size,
373
                num_return_sequences=num_return_sequences,
374
                length_penalty=length_penalty,
375
                num_beams=num_beams,
376
                vocab_size=vocab_size,
377
                encoder_outputs=encoder_outputs,
378
                attention_mask=attention_mask,
379
                use_cache=use_cache,
380
            )
381
        else:
382
            output = self._generate_no_beam_search(
383
                input_ids,
384
                cur_len=cur_len,
385
                max_length=max_length,
386
                min_length=min_length,
387
                do_sample=do_sample,
388
                temperature=temperature,
389
                top_k=top_k,
390
                top_p=top_p,
391
                repetition_penalty=repetition_penalty,
392
                no_repeat_ngram_size=no_repeat_ngram_size,
393
                bad_words_ids=bad_words_ids,
394
                bos_token_id=bos_token_id,
395
                pad_token_id=pad_token_id,
396
                eos_token_id=eos_token_id,
397
                decoder_start_token_id=decoder_start_token_id,
398
                batch_size=effective_batch_size,
399
                vocab_size=vocab_size,
400
                encoder_outputs=encoder_outputs,
401
                attention_mask=attention_mask,
402
                use_cache=use_cache,
403
            )
404

405
        return output
406

407
    def _generate_no_beam_search(
408
        self,
409
        input_ids,
410
        cur_len,
411
        max_length,
412
        min_length,
413
        do_sample,
414
        temperature,
415
        top_k,
416
        top_p,
417
        repetition_penalty,
418
        no_repeat_ngram_size,
419
        bad_words_ids,
420
        bos_token_id,
421
        pad_token_id,
422
        eos_token_id,
423
        decoder_start_token_id,
424
        batch_size,
425
        vocab_size,
426
        encoder_outputs,
427
        attention_mask,
428
        use_cache,
429
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
433

434
        # length of generated sentences / unfinished sentences
435
        unfinished_sents = tf.ones_like(input_ids[:, 0])
436
        sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
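        # unfinished_sents holds a 1 for every sequence that has not yet produced eos_token_id;
        # sent_lengths starts at max_length and is overwritten with the actual length once a
        # sequence finishes, so finished rows can later be padded consistently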
437

438
        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
439

440
        while cur_len < max_length:
441
            model_inputs = self.prepare_inputs_for_generation(
442
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
443
            )
444
            outputs = self(**model_inputs)
445
            next_token_logits = outputs[0][:, -1, :]
446

447
            # if model has past, then set the past variable to speed up decoding
448
            if self._use_cache(outputs, use_cache):
449
                past = outputs[1]
450

451
            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
452
            if repetition_penalty != 1.0:
453
                next_token_logits_penalties = _create_next_token_logits_penalties(
454
                    input_ids, next_token_logits, repetition_penalty
455
                )
456
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
457

458
            if no_repeat_ngram_size > 0:
459
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
460
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
461
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
462
                # create banned_tokens boolean mask
463
                banned_tokens_indices_mask = []
464
                for banned_tokens_slice in banned_tokens:
465
                    banned_tokens_indices_mask.append(
466
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
467
                    )
468

469
                next_token_logits = set_tensor_by_indices_to_value(
470
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
471
                )
472

473
            if bad_words_ids is not None:
474
                # calculate a list of banned tokens according to bad words
475
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
476

477
                banned_tokens_indices_mask = []
478
                for banned_tokens_slice in banned_tokens:
479
                    banned_tokens_indices_mask.append(
480
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
481
                    )
482

483
                next_token_logits = set_tensor_by_indices_to_value(
484
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
485
                )
486

487
            # set eos token prob to zero if min_length is not reached
488
            if eos_token_id is not None and cur_len < min_length:
489
                # create eos_token_id boolean mask
490
                is_token_logit_eos_token = tf.convert_to_tensor(
                    [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
492
                )
493
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
494

495
                next_token_logits = set_tensor_by_indices_to_value(
496
                    next_token_logits, eos_token_indices_mask, -float("inf")
497
                )
498

499
            if do_sample:
500
                # Temperature (higher temperature => more likely to sample low probability tokens)
501
                if temperature != 1.0:
502
                    next_token_logits = next_token_logits / temperature
503
                # Top-p/top-k filtering
504
                next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
505
                # Sample
506
                next_token = tf.squeeze(
507
                    tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
508
                )
509
            else:
510
                # Greedy decoding
511
                next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
512

513
            # update generations and finished sentences
514
            if eos_token_id is not None:
515
                # pad finished sentences if eos_token_id exist
516
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
517
            else:
518
                tokens_to_add = next_token
519

520
            # add token and increase length by one
521
            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
522
            cur_len = cur_len + 1
523

524
            if eos_token_id is not None:
525
                eos_in_sents = tokens_to_add == eos_token_id
526
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
527
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
528
                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
529
                )
530
                sent_lengths = (
531
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
532
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
533
                )
534

535
                # unfinished_sents is set to zero if eos in sentence
536
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
537

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
539
            if tf.math.reduce_max(unfinished_sents) == 0:
540
                break
541

542
            # extend attention_mask for new generated input if only decoder
543
            if self.config.is_encoder_decoder is False:
544
                attention_mask = tf.concat(
545
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
546
                )
547

548
        # if there are different sentences lengths in the batch, some batches have to be padded
549
        min_sent_length = tf.math.reduce_min(sent_lengths)
550
        max_sent_length = tf.math.reduce_max(sent_lengths)
551
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
553
            # finished sents are filled with pad_token
554
            padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id
555

556
            # create length masks for tf.where operation
557
            broad_casted_sent_lengths = tf.broadcast_to(
558
                tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
559
            )
560
            broad_casted_range = tf.transpose(
561
                tf.broadcast_to(tf.expand_dims(tf.range(max_sent_length), -1), [max_sent_length, batch_size])
562
            )
563

564
            decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
565
        else:
566
            decoded = input_ids
567

568
        return decoded
569

570
    def _generate_beam_search(
571
        self,
572
        input_ids,
573
        cur_len,
574
        max_length,
575
        min_length,
576
        do_sample,
577
        early_stopping,
578
        temperature,
579
        top_k,
580
        top_p,
581
        repetition_penalty,
582
        no_repeat_ngram_size,
583
        bad_words_ids,
584
        bos_token_id,
585
        pad_token_id,
586
        decoder_start_token_id,
587
        eos_token_id,
588
        batch_size,
589
        num_return_sequences,
590
        length_penalty,
591
        num_beams,
592
        vocab_size,
593
        encoder_outputs,
594
        attention_mask,
595
        use_cache,
596
    ):
597
        """ Generate sequences for each example with beam search.
598
        """
599

600
        # generated hypotheses
601
        generated_hyps = [
602
            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
603
            for _ in range(batch_size)
604
        ]
605

606
        # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
607
        if do_sample is False:
608
            beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
609
            beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
610
            beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
611
        else:
612
            beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
613

614
        beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
615

616
        # cache compute states
617
        past = encoder_outputs
618

619
        # done sentences
620
        done = [False for _ in range(batch_size)]
621

622
        while cur_len < max_length:
623
            model_inputs = self.prepare_inputs_for_generation(
624
                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
625
            )
626
            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
627
            next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
628

629
            # if model has past, then set the past variable to speed up decoding
630
            if self._use_cache(outputs, use_cache):
631
                past = outputs[1]
632

633
            # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
634
            if repetition_penalty != 1.0:
635
                next_token_logits_penalties = _create_next_token_logits_penalties(
636
                    input_ids, next_token_logits, repetition_penalty
637
                )
638
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
639

640
            # Temperature (higher temperature => more likely to sample low probability tokens)
641
            if temperature != 1.0:
642
                next_token_logits = next_token_logits / temperature
643

            # calculate log softmax score
645
            scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
646

647
            # set eos token prob to zero if min_length is not reached
648
            if eos_token_id is not None and cur_len < min_length:
649
                # create eos_token_id boolean mask
650
                num_batch_hypotheses = batch_size * num_beams
651

652
                is_token_logit_eos_token = tf.convert_to_tensor(
                    [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
654
                )
655
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
656

657
                scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
658

659
            if no_repeat_ngram_size > 0:
660
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
661
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
662
                num_batch_hypotheses = batch_size * num_beams
663
                banned_tokens = calc_banned_ngram_tokens(
664
                    input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
665
                )
666
                # create banned_tokens boolean mask
667
                banned_tokens_indices_mask = []
668
                for banned_tokens_slice in banned_tokens:
669
                    banned_tokens_indices_mask.append(
670
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
671
                    )
672

673
                scores = set_tensor_by_indices_to_value(
674
                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
675
                )
676

677
            if bad_words_ids is not None:
678
                # calculate a list of banned tokens according to bad words
679
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
680

681
                banned_tokens_indices_mask = []
682
                for banned_tokens_slice in banned_tokens:
683
                    banned_tokens_indices_mask.append(
684
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
685
                    )
686

687
                scores = set_tensor_by_indices_to_value(
688
                    scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
689
                )
690

691
            assert shape_list(scores) == [batch_size * num_beams, vocab_size]
692

693
            if do_sample:
694
                _scores = scores + tf.broadcast_to(
695
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
696
                )  # (batch_size * num_beams, vocab_size)
697

698
                # Top-p/top-k filtering
699
                _scores = tf_top_k_top_p_filtering(
700
                    _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
701
                )  # (batch_size * num_beams, vocab_size)
702
                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
703
                _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
704

705
                next_tokens = sample_without_replacement(
706
                    _scores, num_samples=2 * num_beams
707
                )  # (batch_size, 2 * num_beams)
708
                # Compute next scores
709
                next_scores = tf.gather(_scores, next_tokens, batch_dims=1)  # (batch_size, 2 * num_beams)
710

711
                # sort the sampled vector to make sure that the first num_beams samples are the best
712
                next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
713
                next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
714
                next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1)  # (batch_size, num_beams * 2)
715
            else:
716
                # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
717
                next_scores = scores + tf.broadcast_to(
718
                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
719
                )  # (batch_size * num_beams, vocab_size)
720

                # re-organize to group the beams together (we are keeping the top hypotheses across beams)
722
                next_scores = tf.reshape(
723
                    next_scores, (batch_size, num_beams * vocab_size)
724
                )  # (batch_size, num_beams * vocab_size)
725

726
                next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
727

728
            assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
729

730
            # next batch beam content
731
            next_batch_beam = []
732

733
            # for each sentence
734
            for batch_idx in range(batch_size):
735

736
                # if we are done with this sentence
737
                if done[batch_idx]:
738
                    assert (
739
                        len(generated_hyps[batch_idx]) >= num_beams
740
                    ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
741
                    assert (
742
                        eos_token_id is not None and pad_token_id is not None
743
                    ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
744
                    next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
745
                    continue
746

747
                # next sentence beam content
748
                next_sent_beam = []
749

750
                # next tokens for this sentence
751
                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
752
                    zip(next_tokens[batch_idx], next_scores[batch_idx])
753
                ):
754
                    # get beam and token IDs
755
                    beam_id = beam_token_id // vocab_size
756
                    token_id = beam_token_id % vocab_size
757

758
                    effective_beam_id = batch_idx * num_beams + beam_id
759
                    # add to generated hypotheses if end of sentence or last iteration
760
                    if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
761
                        # if beam_token does not belong to top num_beams tokens, it should not be added
762
                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
763
                        if is_beam_token_worse_than_top_num_beams:
764
                            continue
765
                        generated_hyps[batch_idx].add(
766
                            tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
767
                        )
768
                    else:
769
                        # add next predicted token if it is not eos_token
770
                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
771

772
                    # the beam for next step is full
773
                    if len(next_sent_beam) == num_beams:
774
                        break
775

776
                # Check if we are done so that we can save a pad step if all(done)
777
                done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
778
                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
779
                )
780

781
                # update next beam content
782
                assert len(next_sent_beam) == num_beams, "Beam should always be full"
783
                next_batch_beam.extend(next_sent_beam)
784
                assert len(next_batch_beam) == num_beams * (batch_idx + 1)
785

786
            # stop when we are done with each sentence
787
            if all(done):
788
                break
789

790
            # sanity check / prepare next batch
791
            assert len(next_batch_beam) == batch_size * num_beams
792
            beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
793
            beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
794
            beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
795

796
            # re-order batch and update current length
797
            input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
798
            input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
799
            cur_len = cur_len + 1
800

801
            # re-order internal states
802
            if past is not None:
803
                past = self._reorder_cache(past, beam_idx)
804

805
            # extend attention_mask for new generated input if only decoder
806
            if self.config.is_encoder_decoder is False:
807
                attention_mask = tf.concat(
808
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
809
                )
810

        # finalize all open beam hypotheses and add them to generated hypotheses
        for batch_idx in range(batch_size):
            # Add all open beam hypotheses to generated_hyps
814
            if done[batch_idx]:
815
                continue
816
            # test that beam scores match previously calculated scores if not eos and batch_idx not done
817
            if eos_token_id is not None and all(
818
                (token_id % vocab_size).numpy().item() != eos_token_id for token_id in next_tokens[batch_idx]
819
            ):
820
                assert tf.reduce_all(
821
                    next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
822
                ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
823
                    next_scores[:, :num_beams][batch_idx], tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
824
                )
825

826
            # need to add best num_beams hypotheses to generated hyps
827
            for beam_id in range(num_beams):
828
                effective_beam_id = batch_idx * num_beams + beam_id
829
                final_score = beam_scores[effective_beam_id].numpy().item()
830
                final_tokens = input_ids[effective_beam_id]
831
                generated_hyps[batch_idx].add(final_tokens, final_score)
832

833
        # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
834
        output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
835
        output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
836

837
        # select the best hypotheses
838
        sent_lengths_list = []
839
        best = []
840

841
        # retrieve best hypotheses
842
        for i, hypotheses in enumerate(generated_hyps):
843
            sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
844
            for j in range(output_num_return_sequences_per_batch):
845
                best_hyp = sorted_hyps.pop()[1]
846
                sent_lengths_list.append(len(best_hyp))
847
                best.append(best_hyp)
848
        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
849
            output_batch_size, len(best)
850
        )
851

852
        sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
853

854
        # shorter batches are filled with pad_token
855
        if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
857
            sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
858
            decoded_list = []
859

860
            # fill with hypothesis and eos_token_id if necessary
861
            for i, hypo in enumerate(best):
862
                assert sent_lengths[i] == shape_list(hypo)[0]
863
                # if sent_length is max_len do not pad
864
                if sent_lengths[i] == sent_max_len:
865
                    decoded_slice = hypo
866
                else:
867
                    # else pad to sent_max_len
868
                    num_pad_tokens = sent_max_len - sent_lengths[i]
869
                    padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
870
                    decoded_slice = tf.concat([hypo, padding], axis=-1)
871

872
                    # finish sentence with EOS token
873
                    if sent_lengths[i] < max_length:
874
                        decoded_slice = tf.where(
875
                            tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
876
                            eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
877
                            decoded_slice,
878
                        )
879
                # add to list
880
                decoded_list.append(decoded_slice)
881

882
            decoded = tf.stack(decoded_list)
883
        else:
884
            # none of the hypotheses have an eos_token
            assert all(len(hypo) == max_length for hypo in best)
886
            decoded = tf.stack(best)
887

888
        return decoded
889

890
    @staticmethod
891
    def _reorder_cache(past, beam_idx):
892
        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
893

894

895
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
896
    # create logit penalties for already seen input_ids
897
    token_penalties = np.ones(shape_list(logits))
898
    prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
899
    for i, prev_input_id in enumerate(prev_input_ids):
900
        logit_penalized = logits[i].numpy()[prev_input_id]
901
        logit_penalties = np.zeros(logit_penalized.shape)
902
        # if previous logit score is < 0 then multiply repetition penalty else divide
903
        logit_penalties[logit_penalized < 0] = repetition_penalty
904
        logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
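        # multiplying a negative logit by repetition_penalty (> 1) and dividing a positive logit
        # by it both lower the score of already-generated tokens, as described in the CTRL paper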
905
        np.put(token_penalties[i], prev_input_id, logit_penalties)
906
    return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
907

908

909
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
    # Copied from fairseq for no_repeat_ngram in beam_search
911
    if cur_len + 1 < no_repeat_ngram_size:
912
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
913
        return [[] for _ in range(num_hypos)]
914
    generated_ngrams = [{} for _ in range(num_hypos)]
915
    for idx in range(num_hypos):
916
        gen_tokens = prev_input_ids[idx].numpy().tolist()
917
        generated_ngram = generated_ngrams[idx]
918
        for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
919
            prev_ngram_tuple = tuple(ngram[:-1])
920
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
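            # e.g. with no_repeat_ngram_size=2, seeing tokens "... 5 7 ..." maps the prefix (5,)
            # to [7], so 7 will be banned whenever the sequence ends in 5 again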
921

922
    def _get_generated_ngrams(hypo_idx):
923
        # Before decoding the next token, prevent decoding of ngrams that have already appeared
924
        start_idx = cur_len + 1 - no_repeat_ngram_size
925
        ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
926
        return generated_ngrams[hypo_idx].get(ngram_idx, [])
927

928
    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
929
    return banned_tokens
930

931

932
def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
933
    banned_tokens = []
934

935
    def _tokens_match(prev_tokens, tokens):
936
        if len(tokens) == 0:
937
            # if bad word tokens is just one token always ban it
938
            return True
        if len(tokens) > len(prev_tokens):
            # if bad word tokens are longer than the previous tokens they can't be equal
941
            return False
942

943
        if prev_tokens[-len(tokens) :] == tokens:
944
            # if tokens match
945
            return True
946
        else:
947
            return False
948

949
    for prev_input_ids_slice in prev_input_ids:
950
        banned_tokens_slice = []
951

952
        for banned_token_seq in bad_words_ids:
953
            assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
954
                bad_words_ids
955
            )
956

957
            if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
958
                # if tokens do not match continue
959
                continue
960

961
            banned_tokens_slice.append(banned_token_seq[-1])
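            # e.g. (hypothetical ids) for banned_token_seq == [8, 9]: if the sequence currently ends
            # with 8, token 9 is added to the banned list, so the pair (8, 9) is never generated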
962

963
        banned_tokens.append(banned_tokens_slice)
964

965
    return banned_tokens
966

967

968
def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
969
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
970
        Args:
971
            logits: logits distribution shape (batch size, vocabulary size)
972
            if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
973
            if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
974
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
975
            Make sure we keep at least min_tokens_to_keep per batch example in the output
976
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
977
    """
978
    logits_shape = shape_list(logits)
979

980
    if top_k > 0:
981
        top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1])  # Safety check
982
        # Remove all tokens with a probability less than the last token of the top-k
983
        indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
984
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
985

986
    if top_p < 1.0:
987
        sorted_indices = tf.argsort(logits, direction="DESCENDING")
988
        sorted_logits = tf.gather(
989
            logits, sorted_indices, axis=-1, batch_dims=1
990
        )  # expects logits to be of dim (batch_size, vocab_size)
991

992
        cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
993

994
        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
995
        sorted_indices_to_remove = cumulative_probs > top_p
996

997
        if min_tokens_to_keep > 1:
998
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
999
            sorted_indices_to_remove = tf.concat(
1000
                [
1001
                    tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
1002
                    sorted_indices_to_remove[:, min_tokens_to_keep:],
1003
                ],
1004
                -1,
1005
            )
1006

1007
        # Shift the indices to the right to keep also the first token above the threshold
1008
        sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
1009
        sorted_indices_to_remove = tf.concat(
1010
            [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
1011
        )
1012
        # scatter sorted tensors to original indexing
1013
        indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
1014
        logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
1015
    return logits
1016

1017

1018
def scatter_values_on_batch_indices(values, batch_indices):
1019
    shape = shape_list(batch_indices)
1020
    # broadcast batch dim to shape
1021
    broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
1022
    # transform batch_indices to pair_indices
1023
    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
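    # pair_indices has shape (batch_size * vocab_size, 2); each row is a (batch index, vocabulary
    # position) pair, so values given in sorted order are scattered back to their original positions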
1024
    # scatter values to pair indices
1025
    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
1026

1027

1028
def set_tensor_by_indices_to_value(tensor, indices, value):
1029
    # create value_tensor since tensor value assignment is not possible in TF
1030
    value_tensor = tf.zeros_like(tensor) + value
1031
    return tf.where(indices, value_tensor, tensor)
1032

1033

1034
def sample_without_replacement(logits, num_samples):
    """
        Categorical sampling without replacement is currently not implemented in TensorFlow;
        the Gumbel-max trick will do for now.
        See https://github.com/tensorflow/tensorflow/issues/9260 for more info.
    """
1040
    z = -tf.math.log(tf.random.uniform(shape_list(logits), 0, 1))
1041
    _, indices = tf.nn.top_k(logits + z, num_samples)
1042
    return indices
1043

1044

1045
def shape_list(x):
1046
    """Deal with dynamic shape in tensorflow cleanly."""
1047
    static = x.shape.as_list()
1048
    dynamic = tf.shape(x)
1049
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
1050

1051

1052
class BeamHypotheses(object):
1053
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
1054
        """
1055
        Initialize n-best list of hypotheses.
1056
        """
1057
        self.max_length = max_length - 1  # ignoring bos_token
1058
        self.length_penalty = length_penalty
1059
        self.early_stopping = early_stopping
1060
        self.num_beams = num_beams
1061
        self.beams = []
1062
        self.worst_score = 1e9
1063

1064
    def __len__(self):
1065
        """
1066
        Number of hypotheses in the list.
1067
        """
1068
        return len(self.beams)
1069

1070
    def add(self, hyp, sum_logprobs):
1071
        """
1072
        Add a new hypothesis to the list.
1073
        """
1074
        score = sum_logprobs / len(hyp) ** self.length_penalty
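        # length-normalised score: since summed log-probabilities are negative, a length_penalty > 1.0
        # shrinks the magnitude more for longer hypotheses and therefore favours longer sequences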
1075
        if len(self) < self.num_beams or score > self.worst_score:
1076
            self.beams.append((score, hyp))
1077
            if len(self) > self.num_beams:
1078
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
1079
                del self.beams[sorted_scores[0][1]]
1080
                self.worst_score = sorted_scores[1][0]
1081
            else:
1082
                self.worst_score = min(score, self.worst_score)
1083

1084
    def is_done(self, best_sum_logprobs, cur_len):
1085
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
1088
        """
1089

1090
        if len(self) < self.num_beams:
1091
            return False
1092
        elif self.early_stopping:
1093
            return True
1094
        else:
1095
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
1096
            ret = self.worst_score >= cur_score
1097
            return ret