optimum-intel — test_export.py (116 lines · 4.6 KB)
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15

16
import unittest
17
from pathlib import Path
18
from tempfile import TemporaryDirectory
19
from typing import Optional
20

21
from parameterized import parameterized
22
from utils_tests import MODEL_NAMES
23

24
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
25
from optimum.exporters.openvino import export_from_model
26
from optimum.exporters.tasks import TasksManager
27
from optimum.intel import (
28
    OVLatentConsistencyModelPipeline,
29
    OVModelForAudioClassification,
30
    OVModelForCausalLM,
31
    OVModelForFeatureExtraction,
32
    OVModelForImageClassification,
33
    OVModelForMaskedLM,
34
    OVModelForPix2Struct,
35
    OVModelForQuestionAnswering,
36
    OVModelForSeq2SeqLM,
37
    OVModelForSequenceClassification,
38
    OVModelForSpeechSeq2Seq,
39
    OVModelForTokenClassification,
40
    OVStableDiffusionPipeline,
41
    OVStableDiffusionXLImg2ImgPipeline,
42
    OVStableDiffusionXLPipeline,
43
)
44
from optimum.intel.openvino.modeling_base import OVBaseModel
45
from optimum.utils.save_utils import maybe_load_preprocessors
46

47

48
class ExportModelTest(unittest.TestCase):
    """Smoke tests for exporting supported architectures to OpenVINO.

    Each architecture is exported with ``export_from_model`` into a temporary
    directory and then reloaded through the matching ``OVModelFor*`` /
    ``OV*Pipeline`` class to confirm the exported artifacts are usable.
    """

    # model-type key (resolved to a checkpoint via MODEL_NAMES) -> the OV
    # class expected to be able to load the exported model.
    SUPPORTED_ARCHITECTURES = {
        "bert": OVModelForMaskedLM,
        "pix2struct": OVModelForPix2Struct,
        "t5": OVModelForSeq2SeqLM,
        "bart": OVModelForSeq2SeqLM,
        "gpt2": OVModelForCausalLM,
        "distilbert": OVModelForQuestionAnswering,
        "albert": OVModelForSequenceClassification,
        "vit": OVModelForImageClassification,
        "roberta": OVModelForTokenClassification,
        "wav2vec2": OVModelForAudioClassification,
        "whisper": OVModelForSpeechSeq2Seq,
        "blenderbot": OVModelForFeatureExtraction,
        "stable-diffusion": OVStableDiffusionPipeline,
        "stable-diffusion-xl": OVStableDiffusionXLPipeline,
        "stable-diffusion-xl-refiner": OVStableDiffusionXLImg2ImgPipeline,
        "latent-consistency": OVLatentConsistencyModelPipeline,
    }

    def _openvino_export(
        self,
        model_type: str,
        compression_option: Optional[str] = None,
        stateful: bool = True,
    ):
        """Export ``model_type`` to OpenVINO, reload it, and sanity-check it.

        Args:
            model_type: key into ``SUPPORTED_ARCHITECTURES`` / ``MODEL_NAMES``.
            compression_option: optional weight-compression option forwarded
                to ``export_from_model``.
            stateful: whether to request a stateful export.
        """
        ov_class = self.SUPPORTED_ARCHITECTURES[model_type]
        task = ov_class.export_feature
        checkpoint = MODEL_NAMES[model_type]
        library = TasksManager.infer_library_from_model(checkpoint)

        # Some architectures cannot be exported with SDPA attention; force the
        # eager implementation for those when loading.
        load_kwargs = {}
        if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
            load_kwargs["attn_implementation"] = "eager"

        if library == "timm":
            # timm checkpoints are instantiated via the hub prefix, then their
            # attributes are normalized so the exporter can handle them.
            timm_class = TasksManager.get_model_class_for_task(task, library=library)
            model = timm_class(f"hf_hub:{checkpoint}", pretrained=True, exportable=True)
            TasksManager.standardize_model_attributes(checkpoint, model, library_name=library)
        else:
            model = ov_class.auto_model_class.from_pretrained(checkpoint, **load_kwargs)

        # pix2struct needs its preprocessors available at export time.
        is_pix2struct = getattr(model.config, "model_type", None) == "pix2struct"
        preprocessors = maybe_load_preprocessors(checkpoint) if is_pix2struct else None

        # Text-generation models are additionally exported with a KV cache.
        tasks_to_try = [task]
        if "text-generation" in task:
            tasks_to_try.append(task + "-with-past")

        for current_task in tasks_to_try:
            with TemporaryDirectory() as export_dir:
                export_from_model(
                    model=model,
                    output=Path(export_dir),
                    task=current_task,
                    preprocessors=preprocessors,
                    compression_option=compression_option,
                    stateful=stateful,
                )

                use_cache = current_task.endswith("-with-past")
                reloaded = ov_class.from_pretrained(export_dir, use_cache=use_cache)
                self.assertIsInstance(reloaded, OVBaseModel)

                if "text-generation" in task:
                    self.assertEqual(reloaded.use_cache, use_cache)

                if task == "text-generation":
                    # A stateful model also requires the cache to be enabled.
                    self.assertEqual(reloaded.stateful, stateful and use_cache)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_export(self, model_type: str):
        """Export each supported architecture with default options."""
        self._openvino_export(model_type)
117

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.