# NOTE: the two metadata lines that used to be here ("transformers", "86 lines · 2.9 KB")
# were file-viewer residue, not source code, and have been removed.
# coding=utf-8
# Copyright 2023 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import (
    is_torch_available,
    require_optimum,
    require_torch,
    slow,
)


# torch is an optional dependency of transformers; only import it when the
# current environment actually provides it (the @require_torch decorator below
# skips the tests otherwise).
if is_torch_available():
    import torch

@require_torch
@require_optimum
@slow
class BetterTransformerIntegrationTest(unittest.TestCase):
    # Smoke tests only — the exhaustive BetterTransformer suite is maintained in Optimum:
    # https://github.com/huggingface/optimum/tree/main/tests/bettertransformer

    @staticmethod
    def _uses_bettertransformer(model):
        """Return True if any submodule of `model` is a BetterTransformer variant."""
        return any(
            "BetterTransformer" in submodule.__class__.__name__ for _, submodule in model.named_modules()
        )

    def test_transform_and_reverse(self):
        r"""
        Round-trip check: converting a model to BetterTransformer and reversing it
        must restore the original module classes and keep generation output
        identical, including after a save/reload cycle.
        """
        checkpoint = "hf-internal-testing/tiny-random-t5"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        encoded = tokenizer("This is me", return_tensors="pt")

        model = model.to_bettertransformer()

        # Conversion must have swapped in at least one BetterTransformer module.
        self.assertTrue(self._uses_bettertransformer(model))

        reference_output = model.generate(**encoded)

        model = model.reverse_bettertransformer()

        # Reversing must remove every BetterTransformer module again.
        self.assertFalse(self._uses_bettertransformer(model))

        with tempfile.TemporaryDirectory() as saved_dir:
            model.save_pretrained(saved_dir)

            reloaded = AutoModelForSeq2SeqLM.from_pretrained(saved_dir)

            self.assertFalse(self._uses_bettertransformer(reloaded))

            # Generation must be unaffected by the convert/reverse/save/reload cycle.
            reloaded_output = reloaded.generate(**encoded)
            self.assertTrue(torch.allclose(reference_output, reloaded_output))

    def test_error_save_pretrained(self):
        r"""
        `save_pretrained` must refuse (raise ValueError) to serialize a model that
        is still in BetterTransformer mode, and succeed once the conversion has
        been reversed.
        """
        checkpoint = "hf-internal-testing/tiny-random-t5"
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        model = model.to_bettertransformer()

        with tempfile.TemporaryDirectory() as saved_dir:
            # Saving while converted is forbidden ...
            with self.assertRaises(ValueError):
                model.save_pretrained(saved_dir)

            # ... but works again after reversing the conversion.
            model = model.reverse_bettertransformer()
            model.save_pretrained(saved_dir)