transformers

Форк
0
/
test_integration.py 
86 строк · 2.9 Кб
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Team Inc.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import tempfile
17
import unittest
18

19
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
20
from transformers.testing_utils import (
21
    is_torch_available,
22
    require_optimum,
23
    require_torch,
24
    slow,
25
)
26

27

28
if is_torch_available():
29
    import torch
30

31

32
@require_torch
@require_optimum
@slow
class BetterTransformerIntegrationTest(unittest.TestCase):
    # Smoke tests only — the full BetterTransformer test suite lives in the Optimum library:
    # https://github.com/huggingface/optimum/tree/main/tests/bettertransformer

    def _contains_bettertransformer_modules(self, model):
        """Return True if any submodule of ``model`` is a BetterTransformer layer."""
        return any("BetterTransformer" in submodule.__class__.__name__ for _, submodule in model.named_modules())

    def test_transform_and_reverse(self):
        r"""
        Check that the model can be converted to BetterTransformer, generate with it,
        be reversed back, and round-trip through ``save_pretrained`` without leaving
        any converted modules behind — while producing identical generations.
        """
        checkpoint = "hf-internal-testing/tiny-random-t5"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        encoded = tokenizer("This is me", return_tensors="pt")

        # Forward conversion: BetterTransformer modules must now be present.
        model = model.to_bettertransformer()
        self.assertTrue(self._contains_bettertransformer_modules(model))

        output = model.generate(**encoded)

        # Reverse conversion: no BetterTransformer module may remain.
        model = model.reverse_bettertransformer()
        self.assertFalse(self._contains_bettertransformer_modules(model))

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)

            reloaded = AutoModelForSeq2SeqLM.from_pretrained(tmpdirname)
            self.assertFalse(self._contains_bettertransformer_modules(reloaded))

            # The reloaded (reversed) model must generate the same output as the
            # converted model did before reversal.
            reloaded_output = reloaded.generate(**encoded)
            self.assertTrue(torch.allclose(output, reloaded_output))

    def test_error_save_pretrained(self):
        r"""
        ``save_pretrained`` must raise a ``ValueError`` while the model is in
        BetterTransformer mode, and succeed again once the conversion is reversed.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
        model = model.to_bettertransformer()

        with tempfile.TemporaryDirectory() as tmpdirname:
            # Saving a converted model is not supported and must fail loudly.
            with self.assertRaises(ValueError):
                model.save_pretrained(tmpdirname)

            # After reversal, saving must work again.
            model = model.reverse_bettertransformer()
            model.save_pretrained(tmpdirname)
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.