transformers

Форк
0
/
test_get_test_info.py 
109 строк · 4.7 Кб
1
# coding=utf-8
2
# Copyright 2023 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import os
17
import sys
18
import unittest
19

20

21
# Relative paths of the model test files inspected by GetTestInfoTester.
# NOTE(review): these are not anchored to the repository root, so they
# presumably rely on the tests being run from the repo root; confirm.
BERT_TEST_FILE = os.path.join("tests", "models", "bert", "test_modeling_bert.py")
BLIP_TEST_FILE = os.path.join("tests", "models", "blip", "test_modeling_blip.py")

# Climb three levels from this file (file name plus two parent directories),
# take the absolute path, and append its `utils` subdirectory to `sys.path`.
# That lets `get_test_info` be imported as a top-level module below; the
# imports must follow this statement, hence the E402 suppressions.
git_repo_path = os.path.abspath(
    os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
)
sys.path.append(os.path.join(git_repo_path, "utils"))

import get_test_info  # noqa: E402
from get_test_info import (  # noqa: E402
    get_model_to_test_mapping,
    get_model_to_tester_mapping,
    get_test_to_tester_mapping,
)
34

35

36
class GetTestInfoTester(unittest.TestCase):
    """Regression tests for the class-name mappings produced by `get_test_info`.

    Each test runs one mapping helper on ``BERT_TEST_FILE`` and on
    ``BLIP_TEST_FILE``. It then passes every result through
    ``get_test_info.to_json`` and compares it with a hard-coded dictionary of
    class names.
    """

    def _assert_json_mappings(self, actual_mappings, expected_mappings):
        """Assert that each ``to_json``-converted actual mapping equals its expected dict, in order."""
        for actual, expected in zip(actual_mappings, expected_mappings):
            self.assertEqual(get_test_info.to_json(actual), expected)

    def test_get_test_to_tester_mapping(self):
        """Every test class maps to the name of its model tester class."""
        # Both helper calls happen before any assertion is made.
        bert_actual, blip_actual = (
            get_test_to_tester_mapping(test_file) for test_file in (BERT_TEST_FILE, BLIP_TEST_FILE)
        )

        expected_bert = {"BertModelTest": "BertModelTester"}
        expected_blip = dict(
            BlipModelTest="BlipModelTester",
            BlipTextImageModelTest="BlipTextImageModelsModelTester",
            BlipTextModelTest="BlipTextModelTester",
            BlipTextRetrievalModelTest="BlipTextRetrievalModelTester",
            BlipVQAModelTest="BlipVQAModelTester",
            BlipVisionModelTest="BlipVisionModelTester",
        )

        self._assert_json_mappings((bert_actual, blip_actual), (expected_bert, expected_blip))

    def test_get_model_to_test_mapping(self):
        """Every model class maps to the list of test classes that cover it."""
        bert_actual, blip_actual = (
            get_model_to_test_mapping(test_file) for test_file in (BERT_TEST_FILE, BLIP_TEST_FILE)
        )

        # Every BERT model class is expected to be covered by `BertModelTest` alone.
        bert_models = (
            "BertForMaskedLM",
            "BertForMultipleChoice",
            "BertForNextSentencePrediction",
            "BertForPreTraining",
            "BertForQuestionAnswering",
            "BertForSequenceClassification",
            "BertForTokenClassification",
            "BertLMHeadModel",
            "BertModel",
        )
        expected_bert = {model: ["BertModelTest"] for model in bert_models}

        # BLIP model classes are expected to be spread over several test classes.
        expected_blip = {
            model: [test_class]
            for model, test_class in (
                ("BlipForConditionalGeneration", "BlipTextImageModelTest"),
                ("BlipForImageTextRetrieval", "BlipTextRetrievalModelTest"),
                ("BlipForQuestionAnswering", "BlipVQAModelTest"),
                ("BlipModel", "BlipModelTest"),
                ("BlipTextModel", "BlipTextModelTest"),
                ("BlipVisionModel", "BlipVisionModelTest"),
            )
        }

        self._assert_json_mappings((bert_actual, blip_actual), (expected_bert, expected_blip))

    def test_get_model_to_tester_mapping(self):
        """Every model class maps to the list of model tester classes used for it."""
        bert_actual, blip_actual = (
            get_model_to_tester_mapping(test_file) for test_file in (BERT_TEST_FILE, BLIP_TEST_FILE)
        )

        # Every BERT model class is expected to use `BertModelTester` alone.
        bert_models = (
            "BertForMaskedLM",
            "BertForMultipleChoice",
            "BertForNextSentencePrediction",
            "BertForPreTraining",
            "BertForQuestionAnswering",
            "BertForSequenceClassification",
            "BertForTokenClassification",
            "BertLMHeadModel",
            "BertModel",
        )
        expected_bert = {model: ["BertModelTester"] for model in bert_models}

        expected_blip = {
            model: [tester_class]
            for model, tester_class in (
                ("BlipForConditionalGeneration", "BlipTextImageModelsModelTester"),
                ("BlipForImageTextRetrieval", "BlipTextRetrievalModelTester"),
                ("BlipForQuestionAnswering", "BlipVQAModelTester"),
                ("BlipModel", "BlipModelTester"),
                ("BlipTextModel", "BlipTextModelTester"),
                ("BlipVisionModel", "BlipVisionModelTester"),
            )
        }

        self._assert_json_mappings((bert_actual, blip_actual), (expected_bert, expected_blip))
110

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.