paddlenlp

Форк
0
/
test_sentence_aug.py 
74 строки · 3.0 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import os
15
import unittest
16
from tempfile import TemporaryDirectory
17

18
from paddlenlp.dataaug import (
19
    SentenceBackTranslate,
20
    SentenceContinue,
21
    SentenceGenerate,
22
    SentenceSummarize,
23
)
24
from paddlenlp.transformers import AutoModelForConditionalGeneration, AutoTokenizer
25

26

27
class TestSentAug(unittest.TestCase):
    """Smoke tests for the sentence-level augmenters imported from ``paddlenlp.dataaug``.

    Every test builds an augmenter on a tiny ``__internal_testing__`` model,
    augments the two fixture sentences and checks only the *shape* of the
    result: one entry per input sentence, and a fixed number of items in
    each entry. The generated text itself is never inspected.
    """

    def setUp(self):
        """Create the shared input sentences and the ``max_length`` passed to the augmenters."""
        self.sequences = ["人类语言是抽象的信息符号。", "而计算机只能处理数值化的信息。"]
        self.max_length = 3

    def _assert_augmented_shape(self, augmented, expected_per_sequence):
        """Assert that ``augmented`` matches the expected result shape.

        Args:
            augmented: Value returned by an augmenter's ``augment`` call.
            expected_per_sequence: Expected ``len`` of the first and of the
                second entry of ``augmented``.
        """
        self.assertEqual(len(self.sequences), len(augmented))
        self.assertEqual(expected_per_sequence, len(augmented[0]))
        self.assertEqual(expected_per_sequence, len(augmented[1]))

    def test_sent_generate(self):
        """``SentenceGenerate`` yields ``aug.create_n`` items for each input sentence."""
        aug = SentenceGenerate(model_name="__internal_testing__/tiny-random-roformer-sim", max_length=self.max_length)
        augmented = aug.augment(self.sequences)
        self._assert_augmented_shape(augmented, aug.create_n)

    def test_sent_summarize(self):
        """``SentenceSummarize`` loaded from a local ``task_path`` yields ``aug.create_n`` items per sentence."""
        model = AutoModelForConditionalGeneration.from_pretrained(
            "__internal_testing__/tiny-random-mbart", max_length=self.max_length
        )
        tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-mbart")

        # Hold the TemporaryDirectory open as a context manager: taking only
        # ``.name`` from an unreferenced instance lets the directory be removed
        # right away and leaves whatever is written to that path afterwards
        # behind on disk once the test finishes.
        with TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, "model")
            model.save_pretrained(model_path)
            tokenizer.save_pretrained(model_path)

            aug = SentenceSummarize(task_path=model_path)
            augmented = aug.augment(self.sequences)

        self._assert_augmented_shape(augmented, aug.create_n)

    def test_sent_backtranslate(self):
        """``SentenceBackTranslate`` yields exactly one item for each input sentence."""
        aug = SentenceBackTranslate(
            from_model_name="__internal_testing__/tiny-random-mbart",
            to_model_name="__internal_testing__/tiny-random-mbart",
            max_length=self.max_length,
        )
        augmented = aug.augment(self.sequences)
        self._assert_augmented_shape(augmented, 1)

    def test_sent_continue(self):
        """``SentenceContinue`` yields ``aug.create_n`` items for each input sentence."""
        aug = SentenceContinue(model_name="__internal_testing__/tiny-random-gpt", max_length=self.max_length)
        augmented = aug.augment(self.sequences)
        self._assert_augmented_shape(augmented, aug.create_n)
71

72

73
# Script entry point: hand control to the unittest runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
75

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.