# transformers
# 109 lines · 4.7 KB
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16import os17import sys18import unittest19
# The repository root is taken to be three directory levels above this file. Its ``utils``
# directory is appended to ``sys.path``, presumably so that the ``get_test_info`` imports that
# follow resolve to the script kept there (confirm if this file or the script moves).
git_repo_path = os.path.abspath(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
)
sys.path.append(os.path.join(git_repo_path, "utils"))

24import get_test_info # noqa: E40225from get_test_info import ( # noqa: E40226get_model_to_test_mapping,27get_model_to_tester_mapping,28get_test_to_tester_mapping,29)
30
31
32BERT_TEST_FILE = os.path.join("tests", "models", "bert", "test_modeling_bert.py")33BLIP_TEST_FILE = os.path.join("tests", "models", "blip", "test_modeling_blip.py")34
35
class GetTestInfoTester(unittest.TestCase):
    """Check the mappings ``get_test_info`` builds from the BERT and BLIP modeling test files.

    Each test applies one ``get_*_mapping`` helper to ``BERT_TEST_FILE`` and ``BLIP_TEST_FILE``.
    It then compares ``get_test_info.to_json`` of each result against a hard-coded expected dict.
    """

    # BERT model classes expected from the BERT test file. In both model-keyed expectations every
    # one of them maps to a single-element list (``BertModelTest`` / ``BertModelTester``).
    _BERT_MODEL_CLASSES = (
        "BertForMaskedLM",
        "BertForMultipleChoice",
        "BertForNextSentencePrediction",
        "BertForPreTraining",
        "BertForQuestionAnswering",
        "BertForSequenceClassification",
        "BertForTokenClassification",
        "BertLMHeadModel",
        "BertModel",
    )

    def _assert_mappings(self, build_mapping, expected_bert, expected_blip):
        """Apply ``build_mapping`` to both test files and compare each result with its expectation.

        Args:
            build_mapping: one of the ``get_*_mapping`` helpers; called with a test-file path.
            expected_bert: expected ``get_test_info.to_json`` form of the BERT mapping.
            expected_blip: expected ``get_test_info.to_json`` form of the BLIP mapping.
        """
        # Both mappings are built before any comparison; BERT is then checked before BLIP.
        results = [build_mapping(test_file) for test_file in (BERT_TEST_FILE, BLIP_TEST_FILE)]
        for result, expected in zip(results, (expected_bert, expected_blip)):
            self.assertEqual(get_test_info.to_json(result), expected)

    def test_get_test_to_tester_mapping(self):
        self._assert_mappings(
            get_test_to_tester_mapping,
            expected_bert={"BertModelTest": "BertModelTester"},
            expected_blip={
                "BlipModelTest": "BlipModelTester",
                "BlipTextImageModelTest": "BlipTextImageModelsModelTester",
                "BlipTextModelTest": "BlipTextModelTester",
                "BlipTextRetrievalModelTest": "BlipTextRetrievalModelTester",
                "BlipVQAModelTest": "BlipVQAModelTester",
                "BlipVisionModelTest": "BlipVisionModelTester",
            },
        )

    def test_get_model_to_test_mapping(self):
        self._assert_mappings(
            get_model_to_test_mapping,
            expected_bert={model: ["BertModelTest"] for model in self._BERT_MODEL_CLASSES},
            expected_blip={
                "BlipForConditionalGeneration": ["BlipTextImageModelTest"],
                "BlipForImageTextRetrieval": ["BlipTextRetrievalModelTest"],
                "BlipForQuestionAnswering": ["BlipVQAModelTest"],
                "BlipModel": ["BlipModelTest"],
                "BlipTextModel": ["BlipTextModelTest"],
                "BlipVisionModel": ["BlipVisionModelTest"],
            },
        )

    def test_get_model_to_tester_mapping(self):
        self._assert_mappings(
            get_model_to_tester_mapping,
            expected_bert={model: ["BertModelTester"] for model in self._BERT_MODEL_CLASSES},
            expected_blip={
                "BlipForConditionalGeneration": ["BlipTextImageModelsModelTester"],
                "BlipForImageTextRetrieval": ["BlipTextRetrievalModelTester"],
                "BlipForQuestionAnswering": ["BlipVQAModelTester"],
                "BlipModel": ["BlipModelTester"],
                "BlipTextModel": ["BlipTextModelTester"],
                "BlipVisionModel": ["BlipVisionModelTester"],
            },
        )