transformers

Форк
0
/
test_tokenization_gpt2_tf.py 
131 строка · 5.6 Кб
1
import unittest
2
from pathlib import Path
3
from tempfile import TemporaryDirectory
4

5
from transformers import AutoConfig, TFGPT2LMHeadModel, is_keras_nlp_available, is_tf_available
6
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
7
from transformers.testing_utils import require_keras_nlp, require_tf, slow
8

9

10
if is_tf_available():
11
    import tensorflow as tf
12

13

14
if is_keras_nlp_available():
15
    from transformers.models.gpt2 import TFGPT2Tokenizer
16

17

18
# Checkpoint ids for which both the Python and the TF tokenizer are loaded and compared in the tests.
TOKENIZER_CHECKPOINTS = ["openai-community/gpt2"]
19
# Checkpoint id whose configuration is used to build the model wrapped by ModelToSave.
# NOTE(review): despite the "TINY" name, this is the same id that TOKENIZER_CHECKPOINTS lists,
# not a visibly reduced-size checkpoint — confirm this is intended.
TINY_MODEL_CHECKPOINT = "openai-community/gpt2"
20

21
if is_tf_available():
22

23
    class ModelToSave(tf.Module):
        """Bundle a tokenizer and a GPT-2 LM head model into one exportable ``tf.Module``.

        The model is instantiated with ``TFGPT2LMHeadModel.from_config`` using the configuration
        of ``TINY_MODEL_CHECKPOINT``.
        """

        def __init__(self, tokenizer):
            """Store ``tokenizer`` and build the language model from the checkpoint config.

            Args:
                tokenizer: Callable that maps a batch of strings to a mapping whose
                    ``"input_ids"`` entry is a ragged tensor (``serving`` calls
                    ``.to_tensor()`` on it).
            """
            super().__init__()
            self.tokenizer = tokenizer
            config = AutoConfig.from_pretrained(TINY_MODEL_CHECKPOINT)
            self.model = TFGPT2LMHeadModel.from_config(config)

        @tf.function(input_signature=(tf.TensorSpec((None,), tf.string, name="text"),))
        def serving(self, text):
            """Tokenize ``text`` and return the model logits for it.

            Args:
                text: 1-D string tensor holding one input sentence per element.

            Returns:
                The ``"logits"`` entry of the model output for the densified token ids.
            """
            ragged_input_ids = self.tokenizer(text)["input_ids"]
            input_ids_dense = ragged_input_ids.to_tensor()

            # Build the mask from the ragged structure instead of testing ``ids > 0``:
            # ``to_tensor()`` pads with 0, so a genuine token whose id is 0 would be
            # indistinguishable from padding and wrongly masked out.
            input_mask = tf.ones_like(ragged_input_ids, dtype=tf.int32).to_tensor()

            return self.model(input_ids=input_ids_dense, attention_mask=input_mask)["logits"]
41

42

43
@require_tf
@require_keras_nlp
class GPTTokenizationTest(unittest.TestCase):
    """Compare the in-graph ``TFGPT2Tokenizer`` with the Python ``GPT2Tokenizer``.

    The TF tokenizers are usually going to be used as pretrained tokenizers from existing model
    checkpoints, so that's what the tests focus on.
    """

    def setUp(self):
        """Load both tokenizer flavours for every checkpoint and prepare the test sentences."""
        super().setUp()

        self.tokenizers = [GPT2Tokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
        self.tf_tokenizers = [TFGPT2Tokenizer.from_pretrained(checkpoint) for checkpoint in TOKENIZER_CHECKPOINTS]
        # A unittest assertion (rather than a bare ``assert``) still runs under ``python -O``.
        self.assertEqual(len(self.tokenizers), len(self.tf_tokenizers))

        self.test_sentences = [
            "This is a straightforward English test sentence.",
            "This one has some weird characters\rto\nsee\r\nif  those\u00E9break things.",
            "Now we're going to add some Chinese: 一 二 三 一二三",
            "And some much more rare Chinese: 齉 堃 齉堃",
            "Je vais aussi écrire en français pour tester les accents",
            "Classical Irish also has some unusual characters, so in they go: Gaelaċ, ꝼ",
        ]
        self.paired_sentences = list(zip(self.test_sentences, self.test_sentences[::-1]))

    def test_output_equivalence(self):
        """The TF tokenizer must produce the same values as the Python tokenizer, key by key."""
        for tokenizer, tf_tokenizer in zip(self.tokenizers, self.tf_tokenizers):
            for sentence in self.test_sentences:
                python_outputs = tokenizer([sentence], return_tensors="tf")
                tf_outputs = tf_tokenizer([sentence])

                for key in python_outputs.keys():
                    # convert them to numpy to avoid messing with ragged tensors
                    python_outputs_values = python_outputs[key].numpy()
                    tf_outputs_values = tf_outputs[key].numpy()

                    self.assertEqual(python_outputs_values.shape, tf_outputs_values.shape)
                    self.assertTrue(tf.reduce_all(tf.cast(python_outputs_values, tf.int64) == tf_outputs_values))

    @slow
    def test_graph_mode(self):
        """Wrapping the tokenizer in ``tf.function`` must not change its outputs."""
        for tf_tokenizer in self.tf_tokenizers:
            compiled_tokenizer = tf.function(tf_tokenizer)
            for sentence in self.test_sentences:
                test_inputs = tf.constant(sentence)
                compiled_outputs = compiled_tokenizer(test_inputs)
                eager_outputs = tf_tokenizer(test_inputs)

                for key in eager_outputs.keys():
                    self.assertTrue(tf.reduce_all(eager_outputs[key] == compiled_outputs[key]))

    @slow
    def test_saved_model(self):
        """A tokenizer + model bundle must survive a SavedModel export/import round trip."""
        for tf_tokenizer in self.tf_tokenizers:
            model = ModelToSave(tokenizer=tf_tokenizer)
            test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
            out = model.serving(test_inputs)  # Build model with some sample inputs
            with TemporaryDirectory() as tempdir:
                save_path = Path(tempdir) / "saved.model"
                tf.saved_model.save(model, save_path, signatures={"serving_default": model.serving})
                loaded_model = tf.saved_model.load(save_path)
                # Run the loaded model while the export directory still exists, in case the
                # loaded object refers back to files on disk.
                loaded_output = loaded_model.signatures["serving_default"](test_inputs)["output_0"]
            # We may see small differences because the loaded model is compiled, so we need an epsilon for the test
            max_abs_diff = float(tf.reduce_max(tf.abs(out - loaded_output)))
            self.assertLessEqual(max_abs_diff, 1e-5)

    @slow
    def test_from_config(self):
        """A tokenizer rebuilt from ``get_config()`` must tokenize identically to the original."""
        for tf_tokenizer in self.tf_tokenizers:
            test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
            out = tf_tokenizer(test_inputs)  # Build model with some sample inputs

            config = tf_tokenizer.get_config()
            model_from_config = TFGPT2Tokenizer.from_config(config)
            from_config_output = model_from_config(test_inputs)

            for key in from_config_output.keys():
                self.assertTrue(tf.reduce_all(from_config_output[key] == out[key]))

    @slow
    def test_padding(self):
        """With a pad token set, ``input_ids`` must come back exactly ``max_length`` wide."""
        test_inputs = tf.convert_to_tensor([self.test_sentences[0]])
        for tf_tokenizer in self.tf_tokenizers:
            # for the test to run
            tf_tokenizer.pad_token_id = 123123

            for max_length in [3, 5, 1024]:
                out = tf_tokenizer(test_inputs, max_length=max_length)

                out_length = out["input_ids"].numpy().shape[1]

                self.assertEqual(out_length, max_length)
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.