allennlp

Форк
0
/
model_card_test.py 
157 строк · 4.8 Кб
1
from allennlp.common.params import Params
2
from allennlp.common.testing import AllenNlpTestCase
3
from allennlp.common.model_card import ModelCard, ModelUsage, IntendedUse, Paper
4
from allennlp.models import Model
5

6

7
class TestPretrainedModelConfiguration(AllenNlpTestCase):
8
    def test_init(self):
9
        model_card = ModelCard(
10
            id="fake_name",
11
            display_name="Fake Name",
12
            model_details="Model's description",
13
            model_usage=ModelUsage(**{"archive_file": "fake.tar.gz", "overrides": {}}),
14
        )
15

16
        assert model_card.id == "fake_name"
17
        assert model_card.display_name == "Fake Name"
18
        assert model_card.model_usage.archive_file == ModelUsage._storage_location + "fake.tar.gz"
19
        assert model_card.model_details.description == "Model's description"
20

21
    def test_init_registered_model(self):
22
        @Model.register("fake-model")
23
        class FakeModel(Model):
24
            """
25
            This is a fake model with a docstring.
26

27
            # Parameters
28

29
            fake_param1: str
30
            fake_param2: int
31
            """
32

33
            def forward(self, **kwargs):
34
                return {}
35

36
        model_card = ModelCard(**{"id": "this-fake-model", "registered_model_name": "fake-model"})
37

38
        assert model_card.display_name == "FakeModel"
39
        assert model_card.model_details.description == "This is a fake model with a docstring."
40

41
    def test_init_dict_model(self):
42
        class FakeModel(Model):
43
            """
44
            This is a fake model with a docstring.
45

46
            # Parameters
47

48
            fake_param1: str
49
            fake_param2: int
50
            """
51

52
            def forward(self, **kwargs):
53
                return {}
54

55
        model_card = ModelCard(**{"id": "this-fake-model", "model_class": FakeModel})
56

57
        assert model_card.display_name == "FakeModel"
58
        assert model_card.model_details.description == "This is a fake model with a docstring."
59

60
    def test_init_registered_model_override(self):
61
        @Model.register("fake-model-2")
62
        class FakeModel(Model):
63
            """
64
            This is a fake model with a docstring.
65

66
            # Parameters
67

68
            fake_param1: str
69
            fake_param2: int
70
            """
71

72
            def forward(self, **kwargs):
73
                return {}
74

75
        model_card = ModelCard(
76
            **{
77
                "id": "this-fake-model",
78
                "registered_model_name": "fake-model-2",
79
                "model_details": "This is the fake model trained on a dataset.",
80
                "model_class": FakeModel,
81
            }
82
        )
83

84
        assert (
85
            model_card.model_details.description == "This is the fake model trained on a dataset."
86
        )
87

88
    def test_init_model_card_info_obj(self):
89
        @Model.register("fake-model-3")
90
        class FakeModel(Model):
91
            """
92
            This is a fake model with a docstring.
93

94
            # Parameters
95

96
            fake_param1: str
97
            fake_param2: int
98
            """
99

100
            def forward(self, **kwargs):
101
                return {}
102

103
        intended_use = IntendedUse("Use 1", "User 1")
104

105
        model_card = ModelCard(
106
            **{
107
                "id": "this-fake-model",
108
                "registered_model_name": "fake-model-3",
109
                "intended_use": intended_use,
110
            }
111
        )
112

113
        model_card_dict = model_card.to_dict()
114
        assert model_card.display_name == "FakeModel"
115

116
        for key, val in intended_use.__dict__.items():
117
            if val:
118
                assert key in model_card_dict
119
            else:
120
                assert key not in model_card_dict
121

122
    def test_nested_json(self):
123
        @Model.register("fake-model-4")
124
        class FakeModel(Model):
125
            """
126
            This is a fake model with a docstring.
127

128
            # Parameters
129

130
            fake_param1: str
131
            fake_param2: int
132
            """
133

134
            def forward(self, **kwargs):
135
                return {}
136

137
        model_card = ModelCard.from_params(
138
            Params(
139
                {
140
                    "id": "this-fake-model",
141
                    "registered_model_name": "fake-model-4",
142
                    "model_details": {
143
                        "description": "This is the fake model trained on a dataset.",
144
                        "paper": {
145
                            "title": "paper name",
146
                            "url": "paper link",
147
                            "citation": "test citation",
148
                        },
149
                    },
150
                    "training_data": {"dataset": {"name": "dataset 1", "url": "dataset url"}},
151
                }
152
            )
153
        )
154

155
        assert isinstance(model_card.model_details.paper, Paper)
156
        assert model_card.model_details.paper.url == "paper link"
157
        assert model_card.training_data.dataset.name == "dataset 1"
158

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.