from typing import Dict, Tuple, Union

import numpy as np
import pytest
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.common.params import Params
from allennlp.common.testing import AllenNlpTestCase
from allennlp.nn.beam_search import (
    BeamSearch,
    GumbelSampler,
    LengthNormalizedSequenceLogProbabilityScorer,
    MultinomialSampler,
    RepeatedNGramBlockingConstraint,
    StepFunctionTypeNoTimestep,
    StepFunctionTypeWithTimestep,
    TopKSampler,
    TopPSampler,
)
from allennlp.nn.util import min_value_of_dtype
# Transition matrix for the "happy path" tests: row i gives P(next token | current = i).
# The last class (5) is used as the end symbol (see BeamSearchTest.setup_method).
transition_probabilities = torch.tensor(
    [
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.0],  # start (0) -> 1 (0.4), 2 (0.3), 3 (0.2), 4 (0.1)
        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0],  # 1 -> 2
        [0.0, 0.0, 0.0, 1.0, 0.0, 0.0],  # 2 -> 3
        [0.0, 0.0, 0.0, 0.0, 1.0, 0.0],  # 3 -> 4
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # 4 -> 5 (end)
        [0.2, 0.1, 0.2, 0.2, 0.2, 0.1],  # 5 (end) -> arbitrary; should never be followed
    ]
)
# Transition matrix that strongly favors jumping straight to the end symbol (5),
# so unconstrained beam search prefers very short sequences. Used by
# `test_take_short_sequence_step` and `test_min_steps`.
short_sequence_transition_probabilities = torch.tensor(
    [
        [0.0, 0.1, 0.0, 0.0, 0.0, 0.9],  # 0 -> 1 (0.1) or end (0.9)
        [0.0, 0.0, 0.1, 0.0, 0.0, 0.9],  # 1 -> 2 (0.1) or end (0.9)
        [0.0, 0.0, 0.0, 0.1, 0.0, 0.9],  # 2 -> 3 (0.1) or end (0.9)
        [0.0, 0.0, 0.0, 0.0, 0.1, 0.9],  # 3 -> 4 (0.1) or end (0.9)
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # 4 -> end
        [0.2, 0.1, 0.2, 0.2, 0.2, 0.1],  # end -> arbitrary; should never be followed
    ]
)
# Transition matrix with cycles (1 -> {2, 3}, 3 -> 1) used to exercise repeated-ngram
# blocking. The tiny 1e-9 transitions to the end symbol make the end state reachable
# without it ever being chosen by unconstrained beam search.
repeated_ngram_transition_probabilities_0 = torch.tensor(
    [
        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0],  # start -> 1
        [0.0, 0.0, 0.4, 0.6, 0.0, 1e-9],  # 1 -> 2 (0.4) or 3 (0.6)
        [0.0, 0.0, 0.0, 1.0, 0.0, 1e-9],  # 2 -> 3
        [0.0, 1.0, 0.0, 0.0, 0.0, 1e-9],  # 3 -> 1 (creates the cycle)
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # 4 is unreachable
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # end -> end
    ]
)
# A second repeated-ngram matrix in which every state can reach several successors,
# so unigram blocking forces a permutation of states rather than a short cycle.
repeated_ngram_transition_probabilities_1 = torch.tensor(
    [
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.0],
        [0.0, 0.4, 0.3, 0.2, 0.1, 0.1],
        [0.0, 0.0, 0.4, 0.3, 0.2, 0.1],
        [0.0, 0.0, 0.3, 0.4, 0.2, 0.1],
        [0.0, 0.0, 0.2, 0.3, 0.4, 0.1],
        [0.2, 0.1, 0.2, 0.2, 0.2, 0.1],  # end -> arbitrary; should never be followed
    ]
)
# Log-probabilities over 6 classes for a batch of 2, used by the sampler unit tests
# below (multinomial / top-k / top-p / gumbel). Row 0 has zero mass on classes 4-5;
# row 1 has zero mass on classes 0-1.
log_probabilities = torch.log(
    torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]])
)
def get_step_function(
    transition_matrix: torch.Tensor, with_timestep: bool = False
) -> Union[StepFunctionTypeNoTimestep, StepFunctionTypeWithTimestep]:
    """
    Build a beam-search step function that looks up next-token log-probabilities in
    `transition_matrix` (one row per last-predicted token).

    If `with_timestep` is True, the returned callable also accepts a `timestep`
    argument (which it ignores), matching `StepFunctionTypeWithTimestep`.
    """

    def _step_function(
        last_predictions: torch.Tensor, state: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        log_probs_list = []
        for last_token in last_predictions:
            log_probs = torch.log(transition_matrix[last_token.item()])
            log_probs_list.append(log_probs)
        # Shape: (group_size, num_classes); the state passes through unchanged.
        return torch.stack(log_probs_list), state

    if not with_timestep:
        return _step_function

    def _step_function_with_timestep(
        last_predictions: torch.Tensor,
        state: Dict[str, torch.Tensor],
        timestep: int,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # The timestep is irrelevant for a fixed transition matrix; delegate.
        return _step_function(last_predictions, state)

    return _step_function_with_timestep
# Ready-made step functions reused throughout the tests below.
take_step_no_timestep = get_step_function(transition_probabilities)
take_step_with_timestep = get_step_function(transition_probabilities, with_timestep=True)
take_short_sequence_step = get_step_function(short_sequence_transition_probabilities)
class BeamSearchTest(AllenNlpTestCase):
    def setup_method(self):
        """Create a fresh `BeamSearch` and the default expected results before each test."""
        super().setup_method()
        # The end symbol is the last class of the transition matrix.
        self.end_index = transition_probabilities.size()[0] - 1
        self.beam_search = BeamSearch(self.end_index, max_steps=10, beam_size=3)

        # The top-3 sequences (and their probabilities) implied by
        # `transition_probabilities` when starting from token 0.
        self.expected_top_k = np.array([[1, 2, 3, 4, 5], [2, 3, 4, 5, 5], [3, 4, 5, 5, 5]])
        self.expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
expected_top_k: np.array = None,
125
expected_log_probs: np.array = None,
126
beam_search: BeamSearch = None,
127
state: Dict[str, torch.Tensor] = None,
128
take_step=take_step_with_timestep,
130
expected_top_k = expected_top_k if expected_top_k is not None else self.expected_top_k
131
expected_log_probs = (
132
expected_log_probs if expected_log_probs is not None else self.expected_log_probs
136
beam_search = beam_search or self.beam_search
137
beam_size = beam_search.beam_size
139
initial_predictions = torch.tensor([0] * batch_size)
140
top_k, log_probs = beam_search.search(initial_predictions, state, take_step)
143
assert list(top_k.size())[:-1] == [batch_size, beam_size]
144
np.testing.assert_array_equal(top_k[0].numpy(), expected_top_k)
147
assert list(log_probs.size()) == [batch_size, beam_size]
148
np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6)
150
@pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep])
151
def test_search(self, step_function):
152
self._check_results(take_step=step_function)
154
def test_finished_state(self):
156
state["foo"] = torch.tensor([[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]])
159
expected_finished_state = {}
160
expected_finished_state["foo"] = np.array(
181
self._check_results(state=state)
184
for key, array in expected_finished_state.items():
185
np.testing.assert_allclose(state[key].numpy(), array)
187
# NOTE(review): this method is corrupted in the extracted source — original line
# numbers are interleaved and several lines are missing (the `state = {}` init and
# the construction of `seq` used below). Intent per the visible code: verify that a
# state tensor carrying an extra leading dimension (via unsqueeze(0).repeat(2, 1, 1))
# is still expanded and reordered correctly by beam search. Restore from VCS before
# editing; do not attempt to run as-is.
def test_diff_shape_state(self):
189
state["decoder_hidden"] = torch.tensor(
190
[[1, 0, 1], [2, 0, 1], [0, 0, 1], [1, 1, 1], [0, 0, 0]]
192
state["decoder_hidden"] = state["decoder_hidden"].unsqueeze(0).repeat(2, 1, 1)
213
expected_finished_state = {}
214
expected_finished_state["decoder_hidden"] = np.array(seq)
217
self._check_results(state=state)
220
for key, array in expected_finished_state.items():
221
np.testing.assert_allclose(state[key].numpy(), array)
223
def test_batch_size_of_one(self):
224
self._check_results(batch_size=1)
226
def test_greedy_search(self):
227
beam_search = BeamSearch(self.end_index, beam_size=1)
228
expected_top_k = np.array([[1, 2, 3, 4, 5]])
229
expected_log_probs = np.log(np.array([0.4]))
231
expected_top_k=expected_top_k,
232
expected_log_probs=expected_log_probs,
233
beam_search=beam_search,
236
def test_single_step(self):
237
self.beam_search.max_steps = 1
238
expected_top_k = np.array([[1], [2], [3]])
239
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
241
expected_top_k=expected_top_k,
242
expected_log_probs=expected_log_probs,
245
def test_early_stopping(self):
247
Checks case where beam search will reach `max_steps` before finding end tokens.
249
beam_search = BeamSearch(self.end_index, beam_size=3, max_steps=3)
250
expected_top_k = np.array([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
251
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
253
expected_top_k=expected_top_k,
254
expected_log_probs=expected_log_probs,
255
beam_search=beam_search,
258
def test_take_short_sequence_step(self):
260
Tests to ensure the top-k from the short_sequence_transition_probabilities
261
transition matrix is expected
263
self.beam_search.beam_size = 5
264
expected_top_k = np.array(
265
[[5, 5, 5, 5, 5], [1, 5, 5, 5, 5], [1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]
267
expected_log_probs = np.log(np.array([0.9, 0.09, 0.009, 0.0009, 0.0001]))
269
expected_top_k=expected_top_k,
270
expected_log_probs=expected_log_probs,
271
take_step=take_short_sequence_step,
274
def test_min_steps(self):
276
Tests to ensure all output sequences are greater than a specified minimum length.
277
It uses the `take_short_sequence_step` step function, which favors shorter sequences.
278
See `test_take_short_sequence_step`.
280
self.beam_search.beam_size = 1
283
self.beam_search.min_steps = 0
284
expected_top_k = np.array([[5]])
285
expected_log_probs = np.log(np.array([0.9]))
286
with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
288
expected_top_k=expected_top_k,
289
expected_log_probs=expected_log_probs,
290
take_step=take_short_sequence_step,
293
self.beam_search.min_steps = 1
294
expected_top_k = np.array([[1, 5]])
295
expected_log_probs = np.log(np.array([0.09]))
297
expected_top_k=expected_top_k,
298
expected_log_probs=expected_log_probs,
299
take_step=take_short_sequence_step,
302
self.beam_search.min_steps = 2
303
expected_top_k = np.array([[1, 2, 5]])
304
expected_log_probs = np.log(np.array([0.009]))
306
expected_top_k=expected_top_k,
307
expected_log_probs=expected_log_probs,
308
take_step=take_short_sequence_step,
311
self.beam_search.beam_size = 3
312
self.beam_search.min_steps = 2
313
expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]])
314
expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001]))
316
expected_top_k=expected_top_k,
317
expected_log_probs=expected_log_probs,
318
take_step=take_short_sequence_step,
321
def test_different_per_node_beam_size(self):
323
beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1)
324
self._check_results(beam_search=beam_search)
327
beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=2)
328
self._check_results(beam_search=beam_search)
330
def test_catch_bad_config(self):
332
If `per_node_beam_size` (which defaults to `beam_size`) is larger than
333
the size of the target vocabulary, `BeamSearch.search` should raise
334
a ConfigurationError.
336
beam_search = BeamSearch(self.end_index, beam_size=20)
337
with pytest.raises(ConfigurationError):
338
self._check_results(beam_search=beam_search)
340
def test_warn_for_bad_log_probs(self):
345
initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
346
with pytest.warns(RuntimeWarning, match="Negligible log probabilities"):
347
self.beam_search.search(initial_predictions, {}, take_step_no_timestep)
349
def test_empty_sequences(self):
350
initial_predictions = torch.LongTensor([self.end_index - 1, self.end_index - 1])
351
beam_search = BeamSearch(self.end_index, beam_size=1)
352
with pytest.warns(RuntimeWarning, match="Empty sequences predicted"):
353
predictions, log_probs = beam_search.search(
354
initial_predictions, {}, take_step_with_timestep
357
assert list(predictions.size()) == [2, 1, 1]
359
assert list(log_probs.size()) == [2, 1]
360
assert (predictions == self.end_index).all()
361
assert (log_probs == 0).all()
363
def test_default_from_params_params(self):
364
beam_search = BeamSearch.from_params(Params({"beam_size": 2, "end_index": 7}))
365
assert beam_search.beam_size == 2
366
assert beam_search._end_index == 7
368
def test_top_p_search(self):
369
initial_predictions = torch.tensor([0] * 5)
371
take_step = take_step_with_timestep
372
p_sampler = TopPSampler(p=0.8)
374
top_p, log_probs = BeamSearch(
375
self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler
376
).search(initial_predictions, {}, take_step)
378
beam_size = beam_size or 1
382
assert list(top_p.size())[:-1] == [batch_size, beam_size]
384
assert ((0 <= top_p) & (top_p <= 5)).all()
387
assert list(log_probs.size()) == [batch_size, beam_size]
389
@pytest.mark.parametrize("p_val", [-1.0, 1.2, 1.1, float("inf")])
390
def test_p_val(self, p_val):
391
with pytest.raises(ValueError):
392
initial_predictions = torch.tensor([0] * 5)
393
take_step = take_step_with_timestep
395
p_sampler = TopPSampler(p=p_val, with_replacement=True)
397
top_k, log_probs = BeamSearch(
398
self.end_index, beam_size=beam_size, max_steps=10, sampler=p_sampler
399
).search(initial_predictions, {}, take_step)
401
def test_top_k_search(self):
402
initial_predictions = torch.tensor([0] * 5)
404
take_step = take_step_with_timestep
405
k_sampler = TopKSampler(k=5, with_replacement=True)
407
top_k, log_probs = BeamSearch(
408
self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler
409
).search(initial_predictions, {}, take_step)
411
beam_size = beam_size or 1
415
assert list(top_k.size())[:-1] == [batch_size, beam_size]
417
assert ((0 <= top_k) & (top_k <= 5)).all()
420
assert list(log_probs.size()) == [batch_size, beam_size]
422
@pytest.mark.parametrize("k_val", [-1, 0])
423
def test_k_val(self, k_val):
424
with pytest.raises(ValueError):
425
initial_predictions = torch.tensor([0] * 5)
426
take_step = take_step_with_timestep
428
k_sampler = TopKSampler(k=k_val, with_replacement=True)
430
top_k, log_probs = BeamSearch(
431
self.end_index, beam_size=beam_size, max_steps=10, sampler=k_sampler
432
).search(initial_predictions, {}, take_step)
434
def test_stochastic_beam_search(self):
435
initial_predictions = torch.tensor([0] * 5)
438
take_step = take_step_with_timestep
440
gumbel_sampler = GumbelSampler()
442
top_k, log_probs = BeamSearch(
443
self.end_index, beam_size=beam_size, max_steps=10, sampler=gumbel_sampler
444
).search(initial_predictions, {}, take_step)
447
assert list(top_k.size())[:-1] == [batch_size, beam_size]
449
assert ((0 <= top_k) & (top_k <= 5)).all()
452
assert list(log_probs.size()) == [batch_size, beam_size]
460
if token == self.end_index:
463
assert token == self.end_index
465
def test_params_sampling(self):
466
beam_search = BeamSearch.from_params(
478
assert beam_search.beam_size == 2
479
assert beam_search._end_index == 7
480
assert beam_search.sampler is not None
482
def test_params_p_sampling(self):
483
beam_search = BeamSearch.from_params(
495
assert beam_search.beam_size == 2
496
assert beam_search._end_index == 7
497
assert beam_search.sampler is not None
499
def test_multinomial_sampler(self):
500
sampler = MultinomialSampler(temperature=0.9)
502
probabilities, classes, state = sampler.sample_nodes(log_probabilities, 3, {"foo": "bar"})
504
assert probabilities.size() == classes.size()
505
assert classes.size() == (2, 3)
506
assert all([x < 4 for x in classes[0]])
507
assert all([x > 1 for x in classes[1]])
509
def test_top_k_sampler(self):
510
sampler = TopKSampler(k=3, temperature=0.9)
512
probabilities, classes, state = sampler.sample_nodes(log_probabilities, 3, {"foo": "bar"})
514
assert probabilities.size() == classes.size()
515
assert classes.size() == (2, 3)
517
assert all([x > 0 and x < 4 for x in classes[0]])
518
assert all([x > 1 and x < 5 for x in classes[1]])
520
def test_top_p_sampler(self):
521
sampler = TopPSampler(p=0.8, temperature=0.9)
523
probabilities, classes, state = sampler.sample_nodes(log_probabilities, 3, {"foo": "bar"})
525
assert probabilities.size() == classes.size()
526
assert classes.size() == (2, 3)
528
assert all([x > 0 and x < 4 for x in classes[0]])
529
assert all([x > 1 and x < 5 for x in classes[1]])
532
sampler = TopPSampler(p=0.7, temperature=1.0)
534
probabilities, classes, state = sampler.sample_nodes(log_probabilities, 2, {"foo": "bar"})
536
assert all([x == 2 or x == 3 or x == 1 for x in classes[0]])
537
assert all([x == 2 or x == 3 for x in classes[1]])
539
def test_gumbel_sampler(self):
540
sampler = GumbelSampler()
541
num_classes = len(log_probabilities[0])
542
sampler_state = sampler.init_state(log_probabilities, batch_size=2, num_classes=num_classes)
544
log_probs, indices, state = sampler.sample_beams(log_probabilities, 3, sampler_state)
546
assert log_probs.size() == indices.size()
547
assert indices.size() == (2, 3)
550
_, sorted_indices = log_probs.sort(dim=-1, descending=True)
551
assert (sorted_indices == torch.arange(3).unsqueeze(0)).all()
553
assert all([x >= 0 and x < 4 for x in indices[0]])
554
assert all([x > 1 and x <= 5 for x in indices[1]])
556
def test_length_normalized_sequence_log_prob_scorer(self):
558
Tests to ensure the sequences are normalized by the correct values. The end token is
559
included in the length. The start token is not.
561
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer()
562
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
563
length_normalization = np.array([5, 4, 3])
564
expected_scores = expected_log_probs / length_normalization
565
self._check_results(expected_log_probs=expected_scores)
569
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
570
length_penalty=length_penalty
572
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
573
length_normalization = np.array(
574
[5**length_penalty, 4**length_penalty, 3**length_penalty]
576
expected_scores = expected_log_probs / length_normalization
577
self._check_results(expected_log_probs=expected_scores)
580
length_penalty = -2.0
581
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
582
length_penalty=length_penalty
584
expected_top_k = np.array([[3, 4, 5, 5, 5], [2, 3, 4, 5, 5], [1, 2, 3, 4, 5]])
585
expected_log_probs = np.log(np.array([0.2, 0.3, 0.4]))
586
length_normalization = np.array(
587
[3**length_penalty, 4**length_penalty, 5**length_penalty]
589
expected_scores = expected_log_probs / length_normalization
590
self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)
595
self.beam_search.max_steps = 4
596
self.beam_search.final_sequence_scorer = LengthNormalizedSequenceLogProbabilityScorer(
597
length_penalty=length_penalty
599
expected_top_k = np.array([[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 5]])
600
expected_log_probs = np.log(np.array([0.4, 0.3, 0.2]))
601
length_normalization = np.array(
602
[4**length_penalty, 4**length_penalty, 3**length_penalty]
604
expected_scores = expected_log_probs / length_normalization
605
self._check_results(expected_top_k=expected_top_k, expected_log_probs=expected_scores)
607
def test_repeated_ngram_blocking_constraint_init_state(self):
610
constraint = RepeatedNGramBlockingConstraint(ngram_size)
612
state = constraint.init_state(batch_size)
613
assert len(state) == batch_size
614
for beam_states in state:
615
assert len(beam_states) == 1
616
beam_state = beam_states[0]
617
assert len(beam_state.keys()) == 2
618
assert len(beam_state["current_prefix"]) == 0
619
assert len(beam_state["seen_ngrams"]) == 0
621
def test_repeated_ngram_blocking_constraint_apply(self):
626
constraint = RepeatedNGramBlockingConstraint(ngram_size)
630
{"current_prefix": [0, 1], "seen_ngrams": {}},
631
{"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}},
634
{"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}},
635
{"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}},
638
log_probabilities = torch.rand(batch_size, beam_size, num_classes)
639
constraint.apply(state, log_probabilities)
641
disallowed_locations = torch.nonzero(
642
log_probabilities == min_value_of_dtype(log_probabilities.dtype)
644
assert len(disallowed_locations) == 4
645
assert [0, 1, 4] in disallowed_locations
646
assert [1, 1, 0] in disallowed_locations
647
assert [1, 1, 1] in disallowed_locations
648
assert [1, 1, 2] in disallowed_locations
650
def test_repeated_ngram_blocking_constraint_update_state(self):
652
constraint = RepeatedNGramBlockingConstraint(ngram_size)
658
{"current_prefix": [0, 1], "seen_ngrams": {}},
659
{"current_prefix": [2, 3], "seen_ngrams": {(2, 3): [4]}},
662
{"current_prefix": [4, 5], "seen_ngrams": {(8, 9): []}},
663
{"current_prefix": [6, 7], "seen_ngrams": {(6, 7): [0, 1, 2]}},
666
predictions = torch.LongTensor([[5, 6], [0, 3]])
667
backpointers = torch.LongTensor([[1, 1], [0, 1]])
671
{"current_prefix": [3, 5], "seen_ngrams": {(2, 3): [4, 5]}},
672
{"current_prefix": [3, 6], "seen_ngrams": {(2, 3): [4, 6]}},
675
{"current_prefix": [5, 0], "seen_ngrams": {(8, 9): [], (4, 5): [0]}},
676
{"current_prefix": [7, 3], "seen_ngrams": {(6, 7): [0, 1, 2, 3]}},
679
updated_state = constraint.update_state(state, predictions, backpointers)
680
assert updated_state == expected_state
682
def test_take_repeated_ngram_step(self):
684
Tests to ensure the top-k from the `repeated_ngram_transition_probabilities_0`
685
transition matrix is expected. The transitions are:
696
The probabilities don't add up 1 because of the 1e-9 transitions to end. That doesn't
697
really matter. Each state just needed some transition to the end probability with a very
698
small probability to ensure it's possible to reach the end state from there and that it
699
isn't selected by beam search without a constraint.
701
Below is the beam search tracing for beam size 2. Any sequence below the
702
line is not selected by beam search. The number that comes before the sequence
703
is the probability of the sequence.
718
0.6 * 1e-9: [1, 3, end]
719
0.4 * 1e-9: [1, 2, end]
726
0.6 * 1e-9: [1, 3, 1, end]
727
0.4 * 1e-9: [1, 2, 3, end]
730
0.36: [1, 3, 1, 3, 1]
731
0.24: [1, 2, 3, 1, 3]
733
0.16: [1, 2, 3, 1, 2]
734
0.4 * 1e-9: [1, 2, 3, 1, end]
735
0.36 * 1e-9: [1, 3, 1, 3, end]
737
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
738
self.beam_search.beam_size = 2
739
self.beam_search.max_steps = 5
740
expected_top_k = np.array([[1, 3, 1, 3, 1], [1, 2, 3, 1, 3]])
741
expected_log_probs = np.log(np.array([0.36, 0.24]))
743
expected_top_k=expected_top_k,
744
expected_log_probs=expected_log_probs,
745
take_step=step_function,
748
def test_repeated_ngram_blocking_end_to_end_unigrams(self):
749
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
750
self.beam_search.beam_size = 2
753
self.beam_search.max_steps = 3
754
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)]
755
expected_top_k = np.array([[1, 2, 3], [1, 3, 5]])
756
expected_log_probs = np.log(np.array([0.4, 0.6 * 1e-9]))
758
expected_top_k=expected_top_k,
759
expected_log_probs=expected_log_probs,
760
take_step=step_function,
763
step_function = get_step_function(repeated_ngram_transition_probabilities_1)
764
self.beam_search.max_steps = 5
765
expected_top_k = np.array([[1, 2, 3, 4, 5], [1, 2, 4, 3, 5]])
766
expected_log_probs = np.log(
767
np.array([0.4 * 0.3 * 0.3 * 0.2 * 0.1, 0.4 * 0.3 * 0.2 * 0.3 * 0.1])
770
expected_top_k=expected_top_k,
771
expected_log_probs=expected_log_probs,
772
take_step=step_function,
775
def test_repeated_ngram_blocking_end_to_end_bigrams(self):
776
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
777
self.beam_search.beam_size = 2
780
self.beam_search.max_steps = 4
781
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=2)]
782
expected_top_k = np.array([[1, 2, 3, 1], [1, 3, 1, 2]])
783
expected_log_probs = np.log(np.array([0.4, 0.24]))
785
expected_top_k=expected_top_k,
786
expected_log_probs=expected_log_probs,
787
take_step=step_function,
790
def test_repeated_ngram_blocking_end_to_end_trigrams(self):
791
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
792
self.beam_search.beam_size = 2
795
self.beam_search.max_steps = 5
796
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=3)]
797
expected_top_k = np.array([[1, 2, 3, 1, 3], [1, 2, 3, 1, 2]])
798
expected_log_probs = np.log(np.array([0.24, 0.16]))
800
expected_top_k=expected_top_k,
801
expected_log_probs=expected_log_probs,
802
take_step=step_function,
805
def test_repeated_ngram_blocking_end_indices(self):
807
Ensures that the ngram blocking does not mess up when one sequence is shorter
808
than another, which would result in repeated "end" symbols.
812
step_function = get_step_function(repeated_ngram_transition_probabilities_0)
813
self.beam_search.beam_size = 2
814
self.beam_search.constraints = [RepeatedNGramBlockingConstraint(ngram_size=1)]
815
expected_top_k = np.array([[1, 3, 5, 5], [1, 2, 3, 5]])
816
expected_log_probs = np.log(np.array([0.6 * 1e-9, 0.4 * 1e-9]))
818
expected_top_k=expected_top_k,
819
expected_log_probs=expected_log_probs,
820
take_step=step_function,