allennlp / pretrained_model_initializer_test.py
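"""Tests for PretrainedModelInitializer, which copies parameter values from a
saved state dict (optionally via parameter name overrides) into a model's
matching parameters."""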
from typing import Dict, Optional
import os
import tempfile
import tarfile

import pytest
import torch

from allennlp.nn import InitializerApplicator, Initializer
from allennlp.nn.initializers import PretrainedModelInitializer
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.params import Params


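# Two small dummy modules. They share `linear_1` (same shape), while the second
# layer is named `linear_2` in _Net1 and `linear_3` in _Net2 so the tests can
# exercise parameter name overrides.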
class _Net1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 10)
        self.linear_2 = torch.nn.Linear(10, 5)
        self.scalar = torch.nn.Parameter(torch.rand(()))

    def forward(self, inputs):
        pass


class _Net2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 10)
        self.linear_3 = torch.nn.Linear(10, 5)
        self.scalar = torch.nn.Parameter(torch.rand(()))

    def forward(self, inputs):
        pass


class TestPretrainedModelInitializer(AllenNlpTestCase):
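    # setup_method saves net2's randomly initialized weights to a file; the tests
    # below use PretrainedModelInitializer to copy those weights into net1.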
    def setup_method(self):
        super().setup_method()
        self.net1 = _Net1()
        self.net2 = _Net2()
        self.temp_file = self.TEST_DIR / "weights.th"
        torch.save(self.net2.state_dict(), self.temp_file)

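    # Helper: two linear layers are considered equal iff their weights and biases
    # match exactly.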
    def _are_equal(self, linear1: torch.nn.Linear, linear2: torch.nn.Linear) -> bool:
        return torch.equal(linear1.weight, linear2.weight) and torch.equal(
            linear1.bias, linear2.bias
        )

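    # Helper: build an InitializerApplicator that applies a PretrainedModelInitializer
    # to every parameter whose name matches `regex`.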
    def _get_applicator(
        self,
        regex: str,
        weights_file_path: str,
        parameter_name_overrides: Optional[Dict[str, str]] = None,
    ) -> InitializerApplicator:
        initializer = PretrainedModelInitializer(weights_file_path, parameter_name_overrides)
        return InitializerApplicator([(regex, initializer)])

    def test_random_initialization(self):
        # The tests in the class rely on the fact that the parameters for
        # `self.net1` and `self.net2` are randomly initialized and not
        # equal at the beginning. This test makes sure that's true
        assert not self._are_equal(self.net1.linear_1, self.net2.linear_1)
        assert not self._are_equal(self.net1.linear_2, self.net2.linear_3)

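    # Constructing the initializer via `Initializer.from_params` with
    # `"type": "pretrained"` should load the saved weights and pick up any
    # `parameter_name_overrides` supplied in the params.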
    def test_from_params(self):
        params = Params({"type": "pretrained", "weights_file_path": self.temp_file})
        initializer = Initializer.from_params(params)
        assert initializer.weights
        assert initializer.parameter_name_overrides == {}

        name_overrides = {"a": "b", "c": "d"}
        params = Params(
            {
                "type": "pretrained",
                "weights_file_path": self.temp_file,
                "parameter_name_overrides": name_overrides,
            }
        )
        initializer = Initializer.from_params(params)
        assert initializer.weights
        assert initializer.parameter_name_overrides == name_overrides

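    # The weights file may also be a .tar.gz archive containing the serialized
    # state dict; the loaded weights should match net2's parameters exactly.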
    def test_from_params_tar_gz(self):
        with tempfile.NamedTemporaryFile(suffix=".tar.gz") as f:
            with tarfile.open(fileobj=f, mode="w:gz") as archive:
                archive.add(self.temp_file, arcname=os.path.basename(self.temp_file))
            f.flush()
            params = Params({"type": "pretrained", "weights_file_path": f.name})
            initializer = Initializer.from_params(params)

        assert initializer.weights
        assert initializer.parameter_name_overrides == {}

        for name, parameter in self.net2.state_dict().items():
            assert torch.equal(parameter, initializer.weights[name])

    def test_default_parameter_names(self):
        # This test initializes net1 to net2's parameters. It doesn't use
        # the parameter name overrides, so it will verify the initialization
        # works if the two parameters' names are the same.
        applicator = self._get_applicator("linear_1.weight|linear_1.bias", self.temp_file)
        applicator(self.net1)
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1)
        assert not self._are_equal(self.net1.linear_2, self.net2.linear_3)

    def test_parameter_name_overrides(self):
        # This test will use the parameter name overrides to initialize all
        # of net1's weights to net2's.
        name_overrides = {"linear_2.weight": "linear_3.weight", "linear_2.bias": "linear_3.bias"}
        applicator = self._get_applicator("linear_*", self.temp_file, name_overrides)
        applicator(self.net1)
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1)
        assert self._are_equal(self.net1.linear_2, self.net2.linear_3)

    def test_size_mismatch(self):
        # This test will verify that an exception is raised when you try
        # to initialize a parameter to a pretrained parameter and they have
        # different sizes
        name_overrides = {"linear_1.weight": "linear_3.weight"}
        applicator = self._get_applicator("linear_1.*", self.temp_file, name_overrides)
        with pytest.raises(ConfigurationError):
            applicator(self.net1)

    def test_zero_dim_tensor(self):
        # This test will verify that a 0-dim tensor can be initialized.
        # It raises IndexError if slicing a tensor to copy the parameter.
        applicator = self._get_applicator("scalar", self.temp_file)
        applicator(self.net1)
        assert torch.equal(self.net1.scalar, self.net2.scalar)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_load_to_gpu_from_gpu(self):
        # This test will make sure that the initializer works on the GPU
        self.net1.cuda(device=0)
        self.net2.cuda(device=0)

        # Verify the parameters are on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True
        assert self.net2.linear_1.weight.is_cuda is True
        assert self.net2.linear_1.bias.is_cuda is True

        # We need to manually save the parameters to a file because setup_method()
        # only does it for the CPU
        temp_file = self.TEST_DIR / "gpu_weights.th"
        torch.save(self.net2.state_dict(), temp_file)

        applicator = self._get_applicator("linear_1.*", temp_file)
        applicator(self.net1)

        # Verify the parameters are still on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True
        assert self.net2.linear_1.weight.is_cuda is True
        assert self.net2.linear_1.bias.is_cuda is True

        # Make sure the weights are identical
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_load_to_cpu_from_gpu(self):
        # This test will load net2's parameters onto the GPU, then use them to
        # initialize net1 on the CPU
        self.net2.cuda(device=0)

        # Verify the parameters are on the GPU
        assert self.net2.linear_1.weight.is_cuda is True
        assert self.net2.linear_1.bias.is_cuda is True

        temp_file = self.TEST_DIR / "gpu_weights.th"
        torch.save(self.net2.state_dict(), temp_file)

        applicator = self._get_applicator("linear_1.*", temp_file)
        applicator(self.net1)

        # Verify the parameters are on the CPU
        assert self.net1.linear_1.weight.is_cuda is False
        assert self.net1.linear_1.bias.is_cuda is False

        # Make sure the weights are identical
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1.cpu())

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_load_to_gpu_from_cpu(self):
        # This test will load net1's parameters onto the GPU, then use net2's
        # on the CPU to initialize net1's parameters.
        self.net1.cuda(device=0)

        # Verify the parameters are on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True

        # net2's parameters are already saved to CPU from setup_method()
        applicator = self._get_applicator("linear_1.*", self.temp_file)
        applicator(self.net1)

        # Verify the parameters are on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True

        # Make sure the weights are identical
        assert self._are_equal(self.net1.linear_1.cpu(), self.net2.linear_1)