17
from paddlenlp.prompt import (
19
PromptDataCollatorWithPadding,
20
PromptModelForSequenceClassification,
23
from paddlenlp.transformers import (
25
AutoModelForSequenceClassification,
30
class PromptModelTest(unittest.TestCase):
33
cls.tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-ernie")
34
cls.model = AutoModelForMaskedLM.from_pretrained("__internal_testing__/tiny-random-ernie")
36
cls.seq_cls_model = AutoModelForSequenceClassification.from_pretrained(
37
"__internal_testing__/tiny-random-ernie", num_labels=cls.num_labels
40
cls.template = AutoTemplate.create_from(
41
prompt="{'soft'}{'text': 'text'}{'mask'}", tokenizer=cls.tokenizer, max_length=512, model=cls.model
43
cls.label_words = {0: "0", 1: "1", 2: "2"}
44
cls.verbalizer = SoftVerbalizer(cls.label_words, cls.tokenizer, cls.model)
45
cls.data_collator = PromptDataCollatorWithPadding(cls.tokenizer, padding=True, return_tensors="pd")
46
cls.prompt_model = PromptModelForSequenceClassification(cls.model, cls.template, cls.verbalizer)
48
def test_sequence_classification_no_labels(self):
    """Run an unlabelled batch through the verbalizer-based prompt model.

    Without a ``labels`` field the tuple output unpacks to exactly
    ``(logits, hidden_states)``. The ``return_dict=True`` output must report
    ``loss`` as None. In both forms the logits are ``(batch, len(label_words))``
    and the hidden states share the batch dimension.
    """
    samples = [{"text": "百度飞桨深度学习框架"}, {"text": "这是一个测试"}]
    features = list(map(self.template, samples))
    batch_size = len(samples)
    label_count = len(self.label_words)

    # Tuple-style output.
    tuple_batch = self.data_collator(features)
    logits, hidden_states = self.prompt_model(return_hidden_states=True, **tuple_batch)
    self.assertEqual(logits.shape[0], batch_size)
    self.assertEqual(logits.shape[1], label_count)
    self.assertEqual(hidden_states.shape[0], batch_size)

    # Dict-style output: the collator is re-run so this call gets a fresh batch, as before.
    dict_batch = self.data_collator(features)
    outputs = self.prompt_model(return_dict=True, return_hidden_states=True, **dict_batch)
    self.assertIsNone(outputs.loss)
    self.assertEqual(outputs.logits.shape[0], batch_size)
    self.assertEqual(outputs.logits.shape[1], label_count)
    self.assertEqual(outputs.hidden_states.shape[0], batch_size)
64
def test_sequence_classification_with_labels(self):
    """Run a labelled batch through the verbalizer-based prompt model.

    The examples carry a ``labels`` field, so the tuple output unpacks to
    ``(loss, logits, hidden_states)`` and the ``return_dict=True`` output
    exposes a non-None ``loss``. In both forms the logits are
    ``(batch, len(label_words))`` and the hidden states share the batch dimension.
    """
    samples = [{"text": "百度飞桨深度学习框架", "labels": 0}, {"text": "这是一个测试", "labels": 1}]
    features = list(map(self.template, samples))
    batch_size = len(samples)
    label_count = len(self.label_words)

    # Tuple-style output: a loss is prepended because labels were supplied.
    tuple_batch = self.data_collator(features)
    loss, logits, hidden_states = self.prompt_model(return_hidden_states=True, **tuple_batch)
    self.assertIsNotNone(loss)
    self.assertEqual(logits.shape[0], batch_size)
    self.assertEqual(logits.shape[1], label_count)
    self.assertEqual(hidden_states.shape[0], batch_size)

    # Dict-style output built from a freshly collated batch.
    dict_batch = self.data_collator(features)
    outputs = self.prompt_model(return_dict=True, return_hidden_states=True, **dict_batch)
    self.assertIsNotNone(outputs.loss)
    self.assertEqual(outputs.logits.shape[0], batch_size)
    self.assertEqual(outputs.logits.shape[1], label_count)
    self.assertEqual(outputs.hidden_states.shape[0], batch_size)
83
def test_efl_no_labels(self):
    """Run an unlabelled batch through a prompt model with no verbalizer.

    The prompt model wraps the sequence-classification backbone with
    ``verbalizer=None``, so the logits width is expected to equal the
    backbone's ``num_labels``. Without labels the tuple output unpacks to
    ``(logits, hidden_states)`` and the dict output reports ``loss`` as None.
    """
    efl_model = PromptModelForSequenceClassification(self.seq_cls_model, self.template, verbalizer=None)
    samples = [{"text": "百度飞桨深度学习框架"}, {"text": "这是一个测试"}]
    features = list(map(self.template, samples))
    batch_size = len(samples)
    class_count = self.num_labels

    # Tuple-style output.
    tuple_batch = self.data_collator(features)
    logits, hidden_states = efl_model(return_hidden_states=True, **tuple_batch)
    self.assertEqual(logits.shape[0], batch_size)
    self.assertEqual(logits.shape[1], class_count)
    self.assertEqual(hidden_states.shape[0], batch_size)

    # Dict-style output built from a freshly collated batch.
    dict_batch = self.data_collator(features)
    outputs = efl_model(return_dict=True, return_hidden_states=True, **dict_batch)
    self.assertIsNone(outputs.loss)
    self.assertEqual(outputs.logits.shape[0], batch_size)
    self.assertEqual(outputs.logits.shape[1], class_count)
    self.assertEqual(outputs.hidden_states.shape[0], batch_size)
100
def test_efl_with_labels(self):
    """Run a labelled batch through a prompt model with no verbalizer.

    The prompt model wraps the sequence-classification backbone with
    ``verbalizer=None``, so the logits width is expected to equal the
    backbone's ``num_labels``. Because the examples carry ``labels``, the
    tuple output unpacks to ``(loss, logits, hidden_states)`` and the dict
    output exposes a non-None ``loss``.
    """
    efl_model = PromptModelForSequenceClassification(self.seq_cls_model, self.template, verbalizer=None)
    samples = [{"text": "百度飞桨深度学习框架", "labels": 0}, {"text": "这是一个测试", "labels": 1}]
    features = list(map(self.template, samples))
    batch_size = len(samples)
    class_count = self.num_labels

    # Tuple-style output: a loss is prepended because labels were supplied.
    tuple_batch = self.data_collator(features)
    loss, logits, hidden_states = efl_model(return_hidden_states=True, **tuple_batch)
    self.assertIsNotNone(loss)
    self.assertEqual(logits.shape[0], batch_size)
    self.assertEqual(logits.shape[1], class_count)
    self.assertEqual(hidden_states.shape[0], batch_size)

    # Dict-style output built from a freshly collated batch.
    dict_batch = self.data_collator(features)
    outputs = efl_model(return_dict=True, return_hidden_states=True, **dict_batch)
    self.assertIsNotNone(outputs.loss)
    self.assertEqual(outputs.logits.shape[0], batch_size)
    self.assertEqual(outputs.logits.shape[1], class_count)
    self.assertEqual(outputs.hidden_states.shape[0], batch_size)
119
if __name__ == "__main__":