text-generation-inference
54 строки · 1.8 Кб
1# test_watermark_logits_processor.py
2
3import os4import numpy as np5import torch6from text_generation_server.utils.watermark import WatermarkLogitsProcessor7
8
# Watermark hyperparameters, overridable through the environment.
# NOTE: os.getenv returns a *string* whenever the variable is actually set,
# which would silently turn GAMMA/DELTA into str; cast to float so the
# constants are numeric in both the default and the overridden case.
GAMMA = float(os.getenv("WATERMARK_GAMMA", 0.5))
DELTA = float(os.getenv("WATERMARK_DELTA", 2.0))
12
def test_seed_rng():
    """Seeding the processor RNG should install a torch.Generator instance."""
    token_ids = [101, 2036, 3731, 102, 2003, 103]
    wm = WatermarkLogitsProcessor()
    wm._seed_rng(token_ids)
    assert isinstance(wm.rng, torch.Generator)
19
def test_get_greenlist_ids():
    """Greenlist ids must stay inside the vocab and cover gamma * vocab_size tokens."""
    token_ids = [101, 2036, 3731, 102, 2003, 103]
    wm = WatermarkLogitsProcessor()
    greenlist = wm._get_greenlist_ids(token_ids, 10, torch.device("cpu"))
    assert max(greenlist) <= 10
    # default gamma is 0.5, so half of the 10-token vocab is green
    assert len(greenlist) == 5
27
def test_calc_greenlist_mask():
    """The boolean mask must flag exactly the greenlist columns of the last row."""
    wm = WatermarkLogitsProcessor()
    logits = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
    green_ids = torch.tensor([2, 3])
    mask = wm._calc_greenlist_mask(logits, green_ids)
    expected = [[False, False, False, False], [False, False, True, True]]
    assert mask.tolist() == expected
    assert mask.shape == logits.shape
36
def test_bias_greenlist_logits():
    """The bias must be added only at positions where the green-token mask is True."""
    wm = WatermarkLogitsProcessor()
    logits = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
    mask = torch.tensor(
        [[False, False, True, True], [False, False, False, True]]
    )
    biased = wm._bias_greenlist_logits(logits, mask, 2.0)
    expected = [[0.5, 0.3, 2.2, 2.8], [0.1, 0.2, 0.7, 2.9]]
    assert np.allclose(biased.tolist(), expected)
    assert biased.shape == logits.shape
48
def test_call():
    """End-to-end __call__ must return scores whose shape matches the input."""
    token_ids = [101, 2036, 3731, 102, 2003, 103]
    wm = WatermarkLogitsProcessor()
    logits = torch.tensor([[0.5, 0.3, 0.2, 0.8], [0.1, 0.2, 0.7, 0.9]])
    out = wm(token_ids, logits)
    assert out.shape == logits.shape