text-generation-inference

Форк
0
54 строки · 1.8 Кб
1
# test_watermark_logits_processor.py
2

3
import os
4
import numpy as np
5
import torch
6
from text_generation_server.utils.watermark import WatermarkLogitsProcessor
7

8

9
GAMMA = os.getenv("WATERMARK_GAMMA", 0.5)
10
DELTA = os.getenv("WATERMARK_DELTA", 2.0)
11

12

13
def test_seed_rng():
14
    input_ids = [101, 2036, 3731, 102, 2003, 103]
15
    processor = WatermarkLogitsProcessor()
16
    processor._seed_rng(input_ids)
17
    assert isinstance(processor.rng, torch.Generator)
18

19

20
def test_get_greenlist_ids():
21
    input_ids = [101, 2036, 3731, 102, 2003, 103]
22
    processor = WatermarkLogitsProcessor()
23
    result = processor._get_greenlist_ids(input_ids, 10, torch.device("cpu"))
24
    assert max(result) <= 10
25
    assert len(result) == int(10 * 0.5)
26

27

28
def test_calc_greenlist_mask():
29
    processor = WatermarkLogitsProcessor()
30
    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
31
    greenlist_token_ids = torch.tensor([2, 3])
32
    result = processor._calc_greenlist_mask(scores, greenlist_token_ids)
33
    assert result.tolist() == [[False, False, False, False], [False, False, True, True]]
34
    assert result.shape == scores.shape
35

36

37
def test_bias_greenlist_logits():
38
    processor = WatermarkLogitsProcessor()
39
    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
40
    green_tokens_mask = torch.tensor(
41
        [[False, False, True, True], [False, False, False, True]]
42
    )
43
    greenlist_bias = 2.0
44
    result = processor._bias_greenlist_logits(scores, green_tokens_mask, greenlist_bias)
45
    assert np.allclose(result.tolist(), [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]])
46
    assert result.shape == scores.shape
47

48

49
def test_call():
50
    input_ids = [101, 2036, 3731, 102, 2003, 103]
51
    processor = WatermarkLogitsProcessor()
52
    scores = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
53
    result = processor(input_ids, scores)
54
    assert result.shape == scores.shape
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.