# optimum-intel · 116 lines · 4.6 KB
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

from parameterized import parameterized
from utils_tests import MODEL_NAMES

from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
from optimum.exporters.openvino import export_from_model
from optimum.exporters.tasks import TasksManager
from optimum.intel import (
    OVLatentConsistencyModelPipeline,
    OVModelForAudioClassification,
    OVModelForCausalLM,
    OVModelForFeatureExtraction,
    OVModelForImageClassification,
    OVModelForMaskedLM,
    OVModelForPix2Struct,
    OVModelForQuestionAnswering,
    OVModelForSeq2SeqLM,
    OVModelForSequenceClassification,
    OVModelForSpeechSeq2Seq,
    OVModelForTokenClassification,
    OVStableDiffusionPipeline,
    OVStableDiffusionXLImg2ImgPipeline,
    OVStableDiffusionXLPipeline,
)
from optimum.intel.openvino.modeling_base import OVBaseModel
from optimum.utils.save_utils import maybe_load_preprocessors

class ExportModelTest(unittest.TestCase):
    """Export a checkpoint through ``export_from_model`` and reload it with its OpenVINO class.

    Each ``SUPPORTED_ARCHITECTURES`` entry maps a key of ``MODEL_NAMES`` to the OpenVINO
    model/pipeline class. That class supplies the export task (``export_feature``), the
    source model loader (``auto_model_class``) and the class used to reload the export.
    """

    # model_type -> OpenVINO class. ``parameterized.expand`` iterates this dict, so its
    # insertion order is the order in which the ``test_export`` cases are generated.
    SUPPORTED_ARCHITECTURES = {
        "bert": OVModelForMaskedLM,
        "pix2struct": OVModelForPix2Struct,
        "t5": OVModelForSeq2SeqLM,
        "bart": OVModelForSeq2SeqLM,
        "gpt2": OVModelForCausalLM,
        "distilbert": OVModelForQuestionAnswering,
        "albert": OVModelForSequenceClassification,
        "vit": OVModelForImageClassification,
        "roberta": OVModelForTokenClassification,
        "wav2vec2": OVModelForAudioClassification,
        "whisper": OVModelForSpeechSeq2Seq,
        "blenderbot": OVModelForFeatureExtraction,
        "stable-diffusion": OVStableDiffusionPipeline,
        "stable-diffusion-xl": OVStableDiffusionXLPipeline,
        "stable-diffusion-xl-refiner": OVStableDiffusionXLImg2ImgPipeline,
        "latent-consistency": OVLatentConsistencyModelPipeline,
    }

    def _openvino_export(
        self,
        model_type: str,
        compression_option: Optional[str] = None,
        stateful: bool = True,
    ):
        """Export the ``model_type`` checkpoint to a temporary directory and check the reloaded model.

        Args:
            model_type: key into both ``SUPPORTED_ARCHITECTURES`` and ``MODEL_NAMES``.
            compression_option: forwarded unchanged to ``export_from_model``.
            stateful: forwarded to ``export_from_model``. For a plain ``"text-generation"``
                task, the reloaded model's ``stateful`` flag must equal ``stateful and use_cache``.
        """
        ov_class = self.SUPPORTED_ARCHITECTURES[model_type]
        task = ov_class.export_feature
        checkpoint = MODEL_NAMES[model_type]
        library = TasksManager.infer_library_from_model(checkpoint)

        # Architectures in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED are loaded with eager attention.
        if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
            from_pretrained_kwargs = {"attn_implementation": "eager"}
        else:
            from_pretrained_kwargs = {}

        if library != "timm":
            model = ov_class.auto_model_class.from_pretrained(checkpoint, **from_pretrained_kwargs)
        else:
            # timm checkpoints are built from the hub and then given standardized attributes.
            timm_class = TasksManager.get_model_class_for_task(task, library=library)
            model = timm_class(f"hf_hub:{checkpoint}", pretrained=True, exportable=True)
            TasksManager.standardize_model_attributes(checkpoint, model, library_name=library)

        # Only pix2struct gets preprocessors handed to the exporter; every other model gets None.
        is_pix2struct = getattr(model.config, "model_type", None) == "pix2struct"
        preprocessors = maybe_load_preprocessors(checkpoint) if is_pix2struct else None

        # Text-generation tasks are exported twice: as-is, then as the "-with-past" variant.
        export_tasks = [task]
        if "text-generation" in task:
            export_tasks.append(task + "-with-past")

        for export_task in export_tasks:
            use_cache = export_task.endswith("-with-past")
            with TemporaryDirectory() as tmp_dir:
                export_from_model(
                    model=model,
                    output=Path(tmp_dir),
                    task=export_task,
                    preprocessors=preprocessors,
                    compression_option=compression_option,
                    stateful=stateful,
                )

                # Reload while the temporary directory still exists.
                reloaded = ov_class.from_pretrained(tmp_dir, use_cache=use_cache)
                self.assertIsInstance(reloaded, OVBaseModel)

                if "text-generation" in task:
                    self.assertEqual(reloaded.use_cache, use_cache)
                    # An exact "text-generation" task always satisfies the containment check above.
                    if task == "text-generation":
                        self.assertEqual(reloaded.stateful, stateful and use_cache)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_export(self, model_type: str):
        """Export and reload ``model_type`` with the helper's default options."""
        self._openvino_export(model_type)