# paddlenlp / test_prompt_model.py
# (Hosting-page header captured with the file: "Форк 0" = 0 forks; "120 строк · 5.9 Кб" = 120 lines, 5.9 KB.)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from paddlenlp.prompt import (
    AutoTemplate,
    PromptDataCollatorWithPadding,
    PromptModelForSequenceClassification,
    SoftVerbalizer,
)
from paddlenlp.transformers import (
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
28

29

30
class PromptModelTest(unittest.TestCase):
    """Shape/loss smoke tests for ``PromptModelForSequenceClassification``.

    Two configurations are exercised:

    * a masked-LM backbone with a ``SoftVerbalizer``; the number of predicted
      classes is asserted to equal ``len(label_words)``;
    * an EFL-style setup: a sequence-classification backbone with
      ``verbalizer=None``; the number of predicted classes is asserted to equal
      ``num_labels``.

    Each configuration is run without and with ``labels`` in the examples, and
    both the tuple output and the ``return_dict=True`` output are checked.
    """

    # Checkpoint shared by the tokenizer and both backbones below.
    MODEL_NAME = "__internal_testing__/tiny-random-ernie"

    @classmethod
    def setUpClass(cls):
        """Build tokenizer, backbones, template, verbalizer and collator once per class."""
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.MODEL_NAME)
        cls.model = AutoModelForMaskedLM.from_pretrained(cls.MODEL_NAME)
        cls.num_labels = 2
        cls.seq_cls_model = AutoModelForSequenceClassification.from_pretrained(
            cls.MODEL_NAME, num_labels=cls.num_labels
        )

        # Prompt layout: a 'soft' part, the example's 'text' field, then a 'mask' slot.
        cls.template = AutoTemplate.create_from(
            prompt="{'soft'}{'text': 'text'}{'mask'}", tokenizer=cls.tokenizer, max_length=512, model=cls.model
        )
        # Three label words even though the examples only use labels 0 and 1;
        # the verbalizer-based model is expected to emit len(label_words) logits.
        cls.label_words = {0: "0", 1: "1", 2: "2"}
        cls.verbalizer = SoftVerbalizer(cls.label_words, cls.tokenizer, cls.model)
        cls.data_collator = PromptDataCollatorWithPadding(cls.tokenizer, padding=True, return_tensors="pd")
        cls.prompt_model = PromptModelForSequenceClassification(cls.model, cls.template, cls.verbalizer)

    def _assert_output_shapes(self, logits, hidden_states, batch_size, num_classes):
        """Assert logits are (batch_size, num_classes, ...) and hidden states lead with batch_size."""
        self.assertEqual(logits.shape[0], batch_size)
        self.assertEqual(logits.shape[1], num_classes)
        self.assertEqual(hidden_states.shape[0], batch_size)

    def _check_prompt_model(self, prompt_model, examples, num_classes, with_labels):
        """Run ``prompt_model`` on ``examples`` in tuple and dict mode and check the outputs.

        Args:
            prompt_model: the ``PromptModelForSequenceClassification`` under test.
            examples: raw example dicts fed through ``self.template``.
            num_classes: expected size of the logits' second dimension.
            with_labels: whether ``examples`` carry ``labels``; when True a loss
                must be returned, otherwise the loss must be absent.
        """
        encoded_examples = [self.template(example) for example in examples]
        batch_size = len(examples)

        # Tuple output: a leading loss element is present only when labels are given.
        outputs = prompt_model(**self.data_collator(encoded_examples), return_hidden_states=True)
        if with_labels:
            loss, logits, hidden_states = outputs
            self.assertIsNotNone(loss)
        else:
            logits, hidden_states = outputs
        self._assert_output_shapes(logits, hidden_states, batch_size, num_classes)

        # Dict output: `loss` is always an attribute, None when no labels are given.
        model_outputs = prompt_model(
            **self.data_collator(encoded_examples), return_dict=True, return_hidden_states=True
        )
        if with_labels:
            self.assertIsNotNone(model_outputs.loss)
        else:
            self.assertIsNone(model_outputs.loss)
        self._assert_output_shapes(model_outputs.logits, model_outputs.hidden_states, batch_size, num_classes)

    def test_sequence_classification_no_labels(self):
        examples = [{"text": "百度飞桨深度学习框架"}, {"text": "这是一个测试"}]
        self._check_prompt_model(self.prompt_model, examples, len(self.label_words), with_labels=False)

    def test_sequence_classification_with_labels(self):
        examples = [{"text": "百度飞桨深度学习框架", "labels": 0}, {"text": "这是一个测试", "labels": 1}]
        self._check_prompt_model(self.prompt_model, examples, len(self.label_words), with_labels=True)

    def test_efl_no_labels(self):
        prompt_model = PromptModelForSequenceClassification(self.seq_cls_model, self.template, verbalizer=None)
        examples = [{"text": "百度飞桨深度学习框架"}, {"text": "这是一个测试"}]
        self._check_prompt_model(prompt_model, examples, self.num_labels, with_labels=False)

    def test_efl_with_labels(self):
        prompt_model = PromptModelForSequenceClassification(self.seq_cls_model, self.template, verbalizer=None)
        examples = [{"text": "百度飞桨深度学习框架", "labels": 0}, {"text": "这是一个测试", "labels": 1}]
        self._check_prompt_model(prompt_model, examples, self.num_labels, with_labels=True)
117

118

119
if __name__ == "__main__":
    # Allow running this test module directly with the stdlib unittest runner.
    unittest.main()

# (Cookie notice from the hosting page, captured along with the file.)
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.