16
from tempfile import TemporaryDirectory
18
from paddlenlp.dataaug import (
19
SentenceBackTranslate,
24
from paddlenlp.transformers import AutoModelForConditionalGeneration, AutoTokenizer
27
class TestSentAug(unittest.TestCase):
29
self.sequences = ["人类语言是抽象的信息符号。", "而计算机只能处理数值化的信息。"]
32
def test_sent_generate(self):
    """Check the output shape produced by ``SentenceGenerate.augment``.

    The result must hold one entry per input sentence, and each of the
    two entries must contain ``create_n`` generated sentences.
    """
    generator = SentenceGenerate(
        model_name="__internal_testing__/tiny-random-roformer-sim",
        max_length=self.max_length,
    )
    outputs = generator.augment(self.sequences)

    # One result list per input sentence.
    self.assertEqual(len(self.sequences), len(outputs))
    # Each result list holds exactly `create_n` candidates.
    for position in (0, 1):
        self.assertEqual(generator.create_n, len(outputs[position]))
39
def test_sent_summarize(self):
    """Check the output shape produced by ``SentenceSummarize.augment``.

    A tiny mBART model and its tokenizer are saved into a temporary
    directory, and ``SentenceSummarize`` is pointed at that directory via
    ``task_path``. The result must hold one entry per input sentence, and
    each of the two entries must contain ``create_n`` summaries.
    """
    model = AutoModelForConditionalGeneration.from_pretrained(
        "__internal_testing__/tiny-random-mbart", max_length=self.max_length
    )
    tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-mbart")

    # Hold the TemporaryDirectory open with a context manager: the directory
    # is guaranteed to exist while the model is saved and used, and it is
    # removed (with everything written into it) when the block exits.
    with TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model")
        model.save_pretrained(model_path)
        tokenizer.save_pretrained(model_path)

        aug = SentenceSummarize(task_path=model_path)
        augmented = aug.augment(self.sequences)

    # One result list per input sentence, each with `create_n` candidates.
    self.assertEqual(len(self.sequences), len(augmented))
    self.assertEqual(aug.create_n, len(augmented[0]))
    self.assertEqual(aug.create_n, len(augmented[1]))
54
def test_sent_backtranslate(self):
    """Check the output shape produced by ``SentenceBackTranslate.augment``.

    The same tiny mBART checkpoint is used for both translation
    directions. The result must hold one entry per input sentence, and
    each of the two entries must contain exactly one sentence.
    """
    checkpoint = "__internal_testing__/tiny-random-mbart"
    translator = SentenceBackTranslate(
        from_model_name=checkpoint,
        to_model_name=checkpoint,
        max_length=self.max_length,
    )
    outputs = translator.augment(self.sequences)

    # One result list per input sentence.
    self.assertEqual(len(self.sequences), len(outputs))
    # Each result list holds a single back-translated sentence.
    for position in (0, 1):
        self.assertEqual(1, len(outputs[position]))
65
def test_sent_continue(self):
    """Check the output shape produced by ``SentenceContinue.augment``.

    The result must hold one entry per input sentence, and each of the
    two entries must contain ``create_n`` continued sentences.
    """
    continuer = SentenceContinue(
        model_name="__internal_testing__/tiny-random-gpt",
        max_length=self.max_length,
    )
    outputs = continuer.augment(self.sequences)

    # One result list per input sentence.
    self.assertEqual(len(self.sequences), len(outputs))
    # Each result list holds exactly `create_n` candidates.
    for position in (0, 1):
        self.assertEqual(continuer.create_n, len(outputs[position]))
73
if __name__ == "__main__":