allennlp / diff_test.py
import sys

import torch
from torch.nn import Parameter

from allennlp.commands import main
from allennlp.common.testing import AllenNlpTestCase


def _clean_output(output: str) -> str:
    # Removes color characters.
    return (
        output.replace("\x1b[0m", "")
        .replace("\x1b[31m", "")
        .replace("\x1b[32m", "")
        .replace("\x1b[33m", "")
        .strip()
    )
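# For illustration: the escape sequences stripped above are the ANSI reset,
# red, green, and yellow color codes, so for example
#     _clean_output("\x1b[32m+d, shape = (3,)\x1b[0m") == "+d, shape = (3,)"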


class TestDiffCommand(AllenNlpTestCase):
    def test_from_archive(self, capsys):
        archive_path = str(
            self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
        )
        sys.argv = ["allennlp", "diff", archive_path, archive_path]
        main()
        captured = capsys.readouterr()
        assert (
            _clean_output(captured.out)
            == """
 _text_field_embedder.token_embedder_tokens.weight, shape = (213, 10)
 _seq2seq_encoder._module.weight_ih_l0, shape = (64, 10)
 _seq2seq_encoder._module.weight_hh_l0, shape = (64, 16)
 _seq2seq_encoder._module.bias_ih_l0, shape = (64,)
 _seq2seq_encoder._module.bias_hh_l0, shape = (64,)
 _feedforward._linear_layers.0.weight, shape = (20, 16)
 _feedforward._linear_layers.0.bias, shape = (20,)
 _classification_layer.weight, shape = (2, 20)
 _classification_layer.bias, shape = (2,)
        """.strip()
        )

    def test_from_huggingface(self, capsys):
        model_id = "hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin"
        sys.argv = [
            "allennlp",
            "diff",
            model_id,
            model_id,
        ]
        main()
        captured = capsys.readouterr()
        assert (
            _clean_output(captured.out)
            == """
 embeddings.word_embeddings.weight, shape = (250, 20)
 embeddings.position_embeddings.weight, shape = (512, 20)
 embeddings.token_type_embeddings.weight, shape = (2, 20)
 embeddings.LayerNorm.weight, shape = (20,)
 embeddings.LayerNorm.bias, shape = (20,)
 encoder.layer.0.attention.self.query.weight, shape = (20, 20)
 encoder.layer.0.attention.self.query.bias, shape = (20,)
 encoder.layer.0.attention.self.key.weight, shape = (20, 20)
 encoder.layer.0.attention.self.key.bias, shape = (20,)
 encoder.layer.0.attention.self.value.weight, shape = (20, 20)
 encoder.layer.0.attention.self.value.bias, shape = (20,)
 encoder.layer.0.attention.output.dense.weight, shape = (20, 20)
 encoder.layer.0.attention.output.dense.bias, shape = (20,)
 encoder.layer.0.attention.output.LayerNorm.weight, shape = (20,)
 encoder.layer.0.attention.output.LayerNorm.bias, shape = (20,)
 encoder.layer.0.intermediate.dense.weight, shape = (40, 20)
 encoder.layer.0.intermediate.dense.bias, shape = (40,)
 encoder.layer.0.output.dense.weight, shape = (20, 40)
 encoder.layer.0.output.dense.bias, shape = (20,)
 encoder.layer.0.output.LayerNorm.weight, shape = (20,)
 encoder.layer.0.output.LayerNorm.bias, shape = (20,)
 pooler.dense.weight, shape = (20, 20)
 pooler.dense.bias, shape = (20,)
        """.strip()
        )

    def test_diff_correct(self, capsys):
        class ModelA(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Parameter(torch.tensor([1.0, 0.0, 0.0]))
                self.b = Parameter(torch.tensor([1.0, 0.0, 0.0]))
                self.c = Parameter(torch.tensor([1.0, 0.0, 0.0]))
                self.e = Parameter(torch.tensor([1.0, 0.0, 0.0]))

        class ModelB(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Parameter(torch.tensor([1.0, 0.0, 0.0]))
                self.b = Parameter(torch.tensor([1.0, 0.0, 0.0, 0.0]))
                self.d = Parameter(torch.tensor([1.0, 0.0, 0.0]))
                self.e = Parameter(torch.tensor([1.0, 0.0, 1.0]))

        model_a = ModelA()
        model_b = ModelB()

        torch.save(model_a.state_dict(), self.TEST_DIR / "checkpoint_a.pt")
        torch.save(model_b.state_dict(), self.TEST_DIR / "checkpoint_b.pt")
        sys.argv = [
            "allennlp",
            "diff",
            str(self.TEST_DIR / "checkpoint_a.pt"),
            str(self.TEST_DIR / "checkpoint_b.pt"),
        ]
        main()
        captured = capsys.readouterr()
        assert (
            _clean_output(captured.out)
            == """
 a, shape = (3,)
-b, shape = (3,)
-c, shape = (3,)
+b, shape = (4,)
+d, shape = (3,)
!e, shape = (3,), distance = 0.5774
        """.strip()
        )
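        # Reading the expected output above: a leading ' ' marks a parameter
        # that is identical in both checkpoints, '-' an entry present only in
        # the first checkpoint (or present there with a different shape),
        # '+' an entry present only in the second, and '!' a parameter with
        # the same shape but different values, reported with a distance.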
        # NOTE: the difference value reported here for 'e', 0.5774, is currently
        # calculated as the square root of the mean squared difference between 'e'
        # in 'model_a' and 'e' in 'model_b':
        # sqrt( (0^2 + 0^2 + 1^2) / 3 ) = sqrt( 1/3 ) = 0.5774
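        # Illustrative sanity check (a sketch added for clarity; it assumes the
        # reported distance is exactly the root-mean-square difference
        # described in the note above):
        rms = torch.sqrt(((model_a.e - model_b.e) ** 2).mean()).item()
        assert round(rms, 4) == 0.5774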

        # Now call again with a higher threshold. The distance for 'e' (0.5774)
        # is below 0.6, so 'e' should no longer be flagged as changed.
        sys.argv = [
            "allennlp",
            "diff",
            str(self.TEST_DIR / "checkpoint_a.pt"),
            str(self.TEST_DIR / "checkpoint_b.pt"),
            "--threshold",
            "0.6",
        ]
        main()
        captured = capsys.readouterr()
        assert (
            _clean_output(captured.out)
            == """
 a, shape = (3,)
-b, shape = (3,)
-c, shape = (3,)
+b, shape = (4,)
+d, shape = (3,)
 e, shape = (3,)
        """.strip()
        )

        # And call a third time with the same threshold but a higher scale.
        # The scale multiplies the reported distance (10 x 0.57735 = 5.7735),
        # so 'e' is flagged as changed again even with the 0.6 threshold.
        sys.argv = [
            "allennlp",
            "diff",
            str(self.TEST_DIR / "checkpoint_a.pt"),
            str(self.TEST_DIR / "checkpoint_b.pt"),
            "--threshold",
            "0.6",
            "--scale",
            "10.0",
        ]
        main()
        captured = capsys.readouterr()
        assert (
            _clean_output(captured.out)
            == """
 a, shape = (3,)
-b, shape = (3,)
-c, shape = (3,)
+b, shape = (4,)
+d, shape = (3,)
!e, shape = (3,), distance = 5.7735
        """.strip()
        )
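

# Example CLI usage outside of these tests (an illustrative sketch based only on
# the invocations exercised above; the paths are placeholders):
#
#     allennlp diff /path/to/checkpoint_a /path/to/checkpoint_b --threshold 0.6 --scale 10.0
#
# The two positional arguments may be local model archives or state-dict files,
# or remote files such as "hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin" as
# in test_from_huggingface above.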