allennlp / beam_search_test.py (821 lines · 31.5 KB)
from typing import Dict, Tuple, Union

import numpy as np
import pytest
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.nn.beam_search import (
    MultinomialSampler,
    BeamSearch,
    TopKSampler,
    TopPSampler,
    GumbelSampler,
    LengthNormalizedSequenceLogProbabilityScorer,
    RepeatedNGramBlockingConstraint,
    StepFunctionTypeWithTimestep,
    StepFunctionTypeNoTimestep,
)
from allennlp.common.params import Params
from allennlp.nn.util import min_value_of_dtype


# fmt: off
transition_probabilities = torch.tensor(
    [  # START 1    2    3    4   END
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.0],  # START -> j
        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],  # 1 -> j
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],  # 2 -> j
        [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],  # 3 -> j
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # 4 -> j
        [0.2, 0.1, 0.2, 0.2, 0.2, 0.1],  # END -> j (doesn't matter)
    ]
)

# A transition matrix that favors shorter sequences over longer ones
short_sequence_transition_probabilities = torch.tensor(
    [  # START 1    2    3    4   END
        [0.0, 0.1, 0.0, 0.0, 0.0, 0.9],  # START -> j
        [0.0, 0.0, 0.1, 0.0, 0.0, 0.9],  # 1 -> j
        [0.0, 0.0, 0.0, 0.1, 0.0, 0.9],  # 2 -> j
        [0.0, 0.0, 0.0, 0.0, 0.1, 0.9],  # 3 -> j
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # 4 -> j
        [0.2, 0.1, 0.2, 0.2, 0.2, 0.1],  # END -> j (doesn't matter)
    ]
)

# A transition matrix that favors repeated ngrams
repeated_ngram_transition_probabilities_0 = torch.tensor(
    [  # START 1    2    3    4   END
        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0],   # START -> j
        [0.0, 0.0, 0.4, 0.6, 0.0, 1e-9],  # 1 -> j
        [0.0, 0.0, 0.0, 1.0, 0.0, 1e-9],  # 2 -> j
        [0.0, 1.0, 0.0, 0.0, 0.0, 1e-9],  # 3 -> j
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],   # 4 -> j (not used)
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],   # END -> j (doesn't matter)
    ]
)

# Another transition matrix that favors repeated ngrams
repeated_ngram_transition_probabilities_1 = torch.tensor(
    [  # START 1    2    3    4   END
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.0],  # START -> j
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.1],  # 1 -> j
        [0.0, 0.0, 0.4, 0.3, 0.2, 0.1],  # 2 -> j
        [0.0, 0.0, 0.3, 0.4, 0.2, 0.1],  # 3 -> j
        [0.0, 0.0, 0.2, 0.3, 0.4, 0.1],  # 4 -> j
        [0.2, 0.1, 0.2, 0.2, 0.2, 0.1],  # END -> j (doesn't matter)
    ]
)
# fmt: on
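# Note on reading the matrices above: row i gives the probability of moving from token i
# to each token j, so under `transition_probabilities` the highest-probability path from
# START is 1 -> 2 -> 3 -> 4 -> END with probability 0.4 * 1.0 * 1.0 * 1.0 * 1.0 = 0.4.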

log_probabilities = torch.log(
    torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]])
)
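# These fixed per-beam log-probabilities are reused by the sampler unit tests further down:
# row 0 puts its mass on classes 0-3 and row 1 on classes 2-5, which is what the assertions
# on `classes[0]` and `classes[1]` in those tests check.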


def get_step_function(
    transition_matrix: torch.Tensor, with_timestep: bool = False
) -> Union[StepFunctionTypeNoTimestep, StepFunctionTypeWithTimestep]:
    def _step_function(
        last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        log_probs_list = []
        for last_token in last_predictions:
            log_probs = torch.log(transition_matrix[last_token.item()])
            log_probs_list.append(log_probs)

        return torch.stack(log_probs_list), state

    if not with_timestep:
        return _step_function

    def _step_function_with_timestep(
        last_predictions: torch.Tensor,
        state: Dict[str, torch.Tensor],
        timestep: int,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        return _step_function(last_predictions, state)

    return _step_function_with_timestep


take_step_no_timestep = get_step_function(transition_probabilities)
take_step_with_timestep = get_step_function(transition_probabilities, with_timestep=True)
take_short_sequence_step = get_step_function(short_sequence_transition_probabilities)
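# The step functions built above follow the contract that `BeamSearch.search` relies on in
# these tests: given `last_predictions` of shape `(group_size,)` (and, for the timestep
# variant, the current timestep), they return log-probabilities of shape
# `(group_size, num_classes)` together with the unchanged state dict. A minimal sketch of
# calling one directly (illustrative only, not used by the tests):
#
#     log_probs, _ = take_step_no_timestep(torch.tensor([0, 1]), {})
#     # log_probs[0] is the log of row 0 (START) of `transition_probabilities`,
#     # log_probs[1] is the log of row 1, so log_probs.shape == (2, 6).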


class BeamSearchTest(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index, max_steps=10, beam_size=3)

        # This is what the top k should look like for each item in the batch.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5], [3, 4, 5, 5, 5]])

        # This is what the log probs should look like for each item in the batch.
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
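        # These expectations follow directly from `transition_probabilities`: the three best
        # first steps from START are tokens 1, 2 and 3 (probabilities 0.4, 0.3 and 0.2), each
        # followed deterministically by the remaining path to END, and beams that finish early
        # are padded with the end index (5).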

    def _check_results(
        self,
        batch_size: int = 5,
        expected_top_k: np.array = None,
        expected_log_probs: np.array = None,
        beam_search: BeamSearch = None,
        state: Dict[str, torch.Tensor] = None,
        take_step=take_step_with_timestep,
    ) -> None:
        expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
        expected_log_probs = (
            expected_log_probs if expected_log_probs is not None else self.expected_log_probs
        )
        state = state or {}

        beam_search = beam_search or self.beam_search
        beam_size = beam_search.beam_size

        initial_predictions = torch.tensor([0] * batch_size)
        top_k, log_probs = beam_search.search(initial_predictions, state, take_step)  # type: ignore

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]
        np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]
        np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6)

    @pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep])
    def test_search(self, step_function):
        self._check_results(take_step=step_function)

    def test_finished_state(self):
        state = {}
        state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]])
        # shape: (batch_size, 3)

        expected_finished_state = {}
        expected_finished_state["foo"] = np.array(
            [
                [1, 0, 1],
                [1, 0, 1],
                [1, 0, 1],
                [2, 0, 1],
                [2, 0, 1],
                [2, 0, 1],
                [0, 0, 1],
                [0, 0, 1],
                [0, 0, 1],
                [1, 1, 1],
                [1, 1, 1],
                [1, 1, 1],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
            ]
        )
        # shape: (batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_diff_shape_state(self):
        state = {}
        state["decoder_hidden"] = torch.tensor(
            [[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]
        )
        state["decoder_hidden"] = state["decoder_hidden"].unsqueeze(0).repeat(2, 1, 1)
        # shape: (2, batch_size, 3)

        seq = [
            [1, 0, 1],
            [1, 0, 1],
            [1, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [2, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [0, 0, 1],
            [1, 1, 1],
            [1, 1, 1],
            [1, 1, 1],
            [0, 0, 0],
            [0, 0, 0],
            [0, 0, 0],
        ]
        seq = [seq] * 2
        expected_finished_state = {}
        expected_finished_state["decoder_hidden"] = np.array(seq)
        # shape: (2, batch_size x beam_size, 3)

        self._check_results(state=state)

        # check finished state.
        for key, array in expected_finished_state.items():
            np.testing.assert_allclose(state[key].numpy(), array)

    def test_batch_size_of_one(self):
        self._check_results(batch_size=1)

    def test_greedy_search(self):
        beam_search = BeamSearch(self.end_index, beam_size=1)
        expected_top_k = np.array([[1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_single_step(self):
        self.beam_search.max_steps = 1
        expected_top_k = np.array([[1], [2], [3]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
        )

    def test_early_stopping(self):
        """
        Checks case where beam search will reach `max_steps` before finding end tokens.
        """
        beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
        expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            beam_search=beam_search,
        )

    def test_take_short_sequence_step(self):
        """
        Tests to ensure the top-k from the short_sequence_transition_probabilities
        transition matrix is as expected.
        """
        self.beam_search.beam_size = 5
        expected_top_k = np.array(
            [[5, 5, 5, 5, 5], [1, 5, 5, 5, 5], [1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]
        )
        expected_log_probs = np.log(np.array([0.9, 0.09, 0.009, 0.0009, 0.0001]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

    def test_min_steps(self):
        """
        Tests to ensure all output sequences are at least a specified minimum length.
        It uses the `take_short_sequence_step` step function, which favors shorter sequences.
        See `test_take_short_sequence_step`.
        """
        self.beam_search.beam_size = 1

        # An empty sequence is allowed under this step function
        self.beam_search.min_steps = 0
        expected_top_k = np.array([[5]])
        expected_log_probs = np.log(np.array([0.9]))
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            self._check_results(
                expected_top_k=expected_top_k,
                expected_log_probs=expected_log_probs,
                take_step=take_short_sequence_step,
            )

        self.beam_search.min_steps = 1
        expected_top_k = np.array([[1, 5]])
        expected_log_probs = np.log(np.array([0.09]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

        self.beam_search.min_steps = 2
        expected_top_k = np.array([[1, 2, 5]])
        expected_log_probs = np.log(np.array([0.009]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

        self.beam_search.beam_size = 3
        self.beam_search.min_steps = 2
        expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=take_short_sequence_step,
        )

    def test_different_per_node_beam_size(self):
        # per_node_beam_size = 1
        beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1)
        self._check_results(beam_search=beam_search)

        # per_node_beam_size = 2
        beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=2)
        self._check_results(beam_search=beam_search)

    def test_catch_bad_config(self):
        """
        If `per_node_beam_size` (which defaults to `beam_size`) is larger than
        the size of the target vocabulary, `BeamSearch.search` should raise
        a ConfigurationError.
        """
        beam_search = BeamSearch(self.end_index, beam_size=20)
        with pytest.raises(ConfigurationError):
            self._check_results(beam_search=beam_search)

    def test_warn_for_bad_log_probs(self):
        # The only valid next step from the initial predictions is the end index.
        # But with a beam size of 3, the call to `topk` to find the 3 most likely
        # next beams will result in 2 new beams that are invalid, in that they have
        # a probability of 0. The beam search should warn us of this.
        initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
        with pytest.warns(RuntimeWarning, match="Negligible log probabilities"):
            self.beam_search.search(initial_predictions, {}, take_step_no_timestep)

    def test_empty_sequences(self):
        initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
        beam_search = BeamSearch(self.end_index, beam_size=1)
        with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
            predictions, log_probs = beam_search.search(
                initial_predictions, {}, take_step_with_timestep
            )
        # predictions should have shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(predictions.size()) == [2, 1, 1]
        # log probs should have shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [2, 1]
        assert (predictions == self.end_index).all()
        assert (log_probs == 0).all()

    def test_default_from_params_params(self):
        beam_search = BeamSearch.from_params(Params({"beam_size": 2, "end_index": 7}))
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7

    def test_top_p_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep
        p_sampler = TopPSampler(p=0.8)

        top_p, log_probs = BeamSearch(
            self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler
        ).search(initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_p should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_p.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_p) & (top_p <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    @pytest.mark.parametrize("p_val", [-1.0, 1.2, 1.1, float("inf")])
    def test_p_val(self, p_val):
        with pytest.raises(ValueError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            p_sampler = TopPSampler(p=p_val, with_replacement=True)

            top_k, log_probs = BeamSearch(
                self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler
            ).search(initial_predictions, {}, take_step)

    def test_top_k_search(self):
        initial_predictions = torch.tensor([0] * 5)
        beam_size = 3
        take_step = take_step_with_timestep
        k_sampler = TopKSampler(k=5, with_replacement=True)

        top_k, log_probs = BeamSearch(
            self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler
        ).search(initial_predictions, {}, take_step)

        beam_size = beam_size or 1
        batch_size = 5

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

    @pytest.mark.parametrize("k_val", [-1, 0])
    def test_k_val(self, k_val):
        with pytest.raises(ValueError):
            initial_predictions = torch.tensor([0] * 5)
            take_step = take_step_with_timestep
            beam_size = 3
            k_sampler = TopKSampler(k=k_val, with_replacement=True)

            top_k, log_probs = BeamSearch(
                self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler
            ).search(initial_predictions, {}, take_step)

    def test_stochastic_beam_search(self):
        initial_predictions = torch.tensor([0] * 5)
        batch_size = 5
        beam_size = 3
        take_step = take_step_with_timestep

        gumbel_sampler = GumbelSampler()

        top_k, log_probs = BeamSearch(
            self.end_index, beam_size=beam_size, max_steps=10, sampler=gumbel_sampler
        ).search(initial_predictions, {}, take_step)

        # top_k should be shape `(batch_size, beam_size, max_predicted_length)`.
        assert list(top_k.size())[:-1] == [batch_size, beam_size]

        assert ((0 <= top_k) & (top_k <= 5)).all()

        # log_probs should be shape `(batch_size, beam_size)`.
        assert list(log_probs.size()) == [batch_size, beam_size]

        # Check to make sure that once the end index is predicted, all subsequent tokens
        # must be the end index.
        for batch in top_k:
            for beam in batch:
                reached_end = False
                for token in beam:
                    if token == self.end_index:
                        reached_end = True
                    if reached_end:
                        assert token == self.end_index

    def test_params_sampling(self):
        beam_search = BeamSearch.from_params(
            Params(
                {
                    "sampler": {
                        "type": "top-k",
                        "k": 4,
                    },
                    "beam_size": 2,
                    "end_index": 7,
                }
            )
        )
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None

    def test_params_p_sampling(self):
        beam_search = BeamSearch.from_params(
            Params(
                {
                    "sampler": {
                        "type": "top-p",
                        "p": 0.8,
                    },
                    "beam_size": 2,
                    "end_index": 7,
                }
            )
        )
        assert beam_search.beam_size == 2
        assert beam_search._end_index == 7
        assert beam_search.sampler is not None

    def test_multinomial_sampler(self):
        sampler = MultinomialSampler(temperature=0.9)

        probabilities, classes, state = sampler.sample_nodes(log_probabilities, 3, {"foo": "bar"})

        assert probabilities.size() == classes.size()
        assert classes.size() == (2, 3)
        assert all([x < 4 for x in classes[0]])
        assert all([x > 1 for x in classes[1]])

    def test_top_k_sampler(self):
        sampler = TopKSampler(k=3, temperature=0.9)

        probabilities, classes, state = sampler.sample_nodes(log_probabilities, 3, {"foo": "bar"})

        assert probabilities.size() == classes.size()
        assert classes.size() == (2, 3)

        assert all([x > 0 and x < 4 for x in classes[0]])
        assert all([x > 1 and x < 5 for x in classes[1]])

    def test_top_p_sampler(self):
        sampler = TopPSampler(p=0.8, temperature=0.9)

        probabilities, classes, state = sampler.sample_nodes(log_probabilities, 3, {"foo": "bar"})

        assert probabilities.size() == classes.size()
        assert classes.size() == (2, 3)

        assert all([x > 0 and x < 4 for x in classes[0]])
        assert all([x > 1 and x < 5 for x in classes[1]])

        # Make sure the filtered classes include the first class that exceeds p
        sampler = TopPSampler(p=0.7, temperature=1.0)

        probabilities, classes, state = sampler.sample_nodes(log_probabilities, 2, {"foo": "bar"})

        assert all([x == 2 or x == 3 or x == 1 for x in classes[0]])
        assert all([x == 2 or x == 3 for x in classes[1]])

    def test_gumbel_sampler(self):
        sampler = GumbelSampler()
        num_classes = len(log_probabilities[0])
        sampler_state = sampler.init_state(log_probabilities, batch_size=2, num_classes=num_classes)

        log_probs, indices, state = sampler.sample_beams(log_probabilities, 3, sampler_state)

        assert log_probs.size() == indices.size()
        assert indices.size() == (2, 3)

        # Make sure the probabilities are sorted.
        _, sorted_indices = log_probs.sort(dim=-1, descending=True)
        assert (sorted_indices == torch.arange(3).unsqueeze(0)).all()

        assert all([x >= 0 and x < 4 for x in indices[0]])
        assert all([x > 1 and x <= 5 for x in indices[1]])

    def test_length_normalized_sequence_log_prob_scorer(self):
        """
        Tests to ensure the sequences are normalized by the correct values. The end token is
        included in the length. The start token is not.
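
        For reference: with the default setup above, the three returned beams are
        [1, 2, 3, 4, END], [2, 3, 4, END] and [3, 4, END], so the unpenalized
        normalization constants below are 5, 4 and 3 (END counts toward the length,
        trailing padding after END does not).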
        """
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        length_normalization = np.array([5, 4, 3])
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_log_probs=expected_scores)

        # Introduce a length penalty
        length_penalty = 2.0
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
            length_penalty=length_penalty
        )
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        length_normalization = np.array(
            [5**length_penalty, 4**length_penalty, 3**length_penalty]
        )
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_log_probs=expected_scores)

        # Pick a length penalty so extreme that the order of the sequences is reversed
        length_penalty = -2.0
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
            length_penalty=length_penalty
        )
        expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]])
        expected_log_probs = np.log(np.array([0.2, 0.3, 0.4]))
        length_normalization = np.array(
            [3**length_penalty, 4**length_penalty, 5**length_penalty]
        )
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)

        # Here, we set the max_steps = 4. This prevents the first sequence from finishing,
        # so its length does not include the end token, whereas the other sequences do.
        length_penalty = 2.0
        self.beam_search.max_steps = 4
        self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
            length_penalty=length_penalty
        )
        expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
        length_normalization = np.array(
            [4**length_penalty, 4**length_penalty, 3**length_penalty]
        )
        expected_scores = expected_log_probs / length_normalization
        self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)

    def test_repeated_ngram_blocking_constraint_init_state(self):
        ngram_size = 3
        batch_size = 2
        constraint = RepeatedNGramBlockingConstraint(ngram_size)

        state = constraint.init_state(batch_size)
        assert len(state) == batch_size
        for beam_states in state:
            assert len(beam_states) == 1
            beam_state = beam_states[0]
            assert len(beam_state.keys()) == 2
            assert len(beam_state["current_prefix"]) == 0
            assert len(beam_state["seen_ngrams"]) == 0

    def test_repeated_ngram_blocking_constraint_apply(self):
        ngram_size = 3
        batch_size = 2
        beam_size = 2
        num_classes = 10
        constraint = RepeatedNGramBlockingConstraint(ngram_size)

        state = [
            [
                {"current_prefix": [0, 1], "seen_ngrams": {}},
                {"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}},
            ],
            [
                {"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}},
                {"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}},
            ],
        ]
        log_probabilities = torch.rand(batch_size, beam_size, num_classes)
        constraint.apply(state, log_probabilities)

        disallowed_locations = torch.nonzero(
            log_probabilities == min_value_of_dtype(log_probabilities.dtype)
        ).tolist()
        assert len(disallowed_locations) == 4
        assert [0, 1, 4] in disallowed_locations
        assert [1, 1, 0] in disallowed_locations
        assert [1, 1, 1] in disallowed_locations
        assert [1, 1, 2] in disallowed_locations

    def test_repeated_ngram_blocking_constraint_update_state(self):
        ngram_size = 3
        constraint = RepeatedNGramBlockingConstraint(ngram_size)

        # We will have [2, 3] -> {5, 6} from batch index 0 and [4, 5] -> {0} and [6, 7] -> {3}
        # from batch index 1.
        state = [
            [
                {"current_prefix": [0, 1], "seen_ngrams": {}},
                {"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}},
            ],
            [
                {"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}},
                {"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}},
            ],
        ]
        predictions = torch.LongTensor([[5, 6], [0, 3]])
        backpointers = torch.LongTensor([[1, 1], [0, 1]])

        expected_state = [
            [
                {"current_prefix": [3, 5], "seen_ngrams": {(2, 3): [4, 5]}},
                {"current_prefix": [3, 6], "seen_ngrams": {(2, 3): [4, 6]}},
            ],
            [
                {"current_prefix": [5, 0], "seen_ngrams": {(8, 9): [], (4, 5): [0]}},
                {"current_prefix": [7, 3], "seen_ngrams": {(6, 7): [0, 1, 2, 3]}},
            ],
        ]
        updated_state = constraint.update_state(state, predictions, backpointers)
        assert updated_state == expected_state

    def test_take_repeated_ngram_step(self):
        """
        Tests to ensure the top-k from the `repeated_ngram_transition_probabilities_0`
        transition matrix is as expected. The transitions are:

            - p(1|start) = 1.0
            - p(2|1) = 0.4
            - p(3|1) = 0.6
            - p(end|1) = 1e-9
            - p(3|2) = 1.0
            - p(end|2) = 1e-9
            - p(1|3) = 1.0
            - p(end|3) = 1e-9

        The probabilities don't add up to 1 because of the 1e-9 transitions to end. That doesn't
        really matter. Each state just needs some transition to the end state with a very small
        probability, so that it is possible to reach the end state from there but it isn't
        selected by beam search without a constraint.

        Below is the beam search tracing for beam size 2. Any sequence below the
        line is not selected by beam search. The number that comes before the sequence
        is the probability of the sequence.

        Step 1
        1.0: [1]

        Step 2
        0.6: [1, 3]
        0.4: [1, 2]
        -----
        1e-9: [1, 2, end]

        Step 3
        0.6: [1, 3, 1]
        0.4: [1, 2, 3]
        -----
        0.6 * 1e-9: [1, 3, end]
        0.4 * 1e-9: [1, 2, end]

        Step 4
        0.4:  [1, 2, 3, 1]
        0.36: [1, 3, 1, 3]
        -----
        0.24:       [1, 3, 1, 2]
        0.6 * 1e-9: [1, 3, 1, end]
        0.4 * 1e-9: [1, 2, 3, end]

        Step 5
        0.36: [1, 3, 1, 3, 1]
        0.24: [1, 2, 3, 1, 3]
        -----
        0.16:        [1, 2, 3, 1, 2]
        0.4 * 1e-9:  [1, 2, 3, 1, end]
        0.36 * 1e-9: [1, 3, 1, 3, end]
        """
        step_function = get_step_function(repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2
        self.beam_search.max_steps = 5
        expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]])
        expected_log_probs = np.log(np.array([0.36, 0.24]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_to_end_unigrams(self):
        step_function = get_step_function(repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2

        # Unigrams: On step 3, [1, 3, 1] will be blocked and [1, 3, end] will take its place
        self.beam_search.max_steps = 3
        self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)]
        expected_top_k = np.array([[1, 2, 3], [1, 3, 5]])
        expected_log_probs = np.log(np.array([0.4, 0.6 * 1e-9]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

        step_function = get_step_function(repeated_ngram_transition_probabilities_1)
        self.beam_search.max_steps = 5
        expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]])
        expected_log_probs = np.log(
            np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])
        )
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_to_end_bigrams(self):
        step_function = get_step_function(repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2

        # Bigrams: On step 4, [1, 3, 1, 3] will be blocked and [1, 3, 1, 2] will take its place
        self.beam_search.max_steps = 4
        self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=2)]
        expected_top_k = np.array([[1, 2, 3, 1], [1, 3, 1, 2]])
        expected_log_probs = np.log(np.array([0.4, 0.24]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_to_end_trigrams(self):
        step_function = get_step_function(repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2

        # Trigrams: On step 5, [1, 3, 1, 3, 1] will be blocked and [1, 2, 3, 1, 2] will take its place
        self.beam_search.max_steps = 5
        self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)]
        expected_top_k = np.array([[1, 2, 3, 1, 3], [1, 2, 3, 1, 2]])
        expected_log_probs = np.log(np.array([0.24, 0.16]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )

    def test_repeated_ngram_blocking_end_indices(self):
        """
        Ensures that the ngram blocking does not mess up when one sequence is shorter
        than another, which would result in repeated "end" symbols.
        """
        # We block unigrams, but 5 (the end symbol) is repeated and it does not mess
        # up the sequence's probability
        step_function = get_step_function(repeated_ngram_transition_probabilities_0)
        self.beam_search.beam_size = 2
        self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)]
        expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]])
        expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9]))
        self._check_results(
            expected_top_k=expected_top_k,
            expected_log_probs=expected_log_probs,
            take_step=step_function,
        )