import os
import tarfile
import tempfile
from typing import Dict, Optional

import pytest
import torch

from allennlp.nn import InitializerApplicator, Initializer
from allennlp.nn.initializers import PretrainedModelInitializer
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.params import Params


class _Net1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 10)
        self.linear_2 = torch.nn.Linear(10, 5)
        self.scalar = torch.nn.Parameter(torch.rand(()))

    def forward(self, inputs):
        pass


class _Net2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 10)
        self.linear_3 = torch.nn.Linear(10, 5)
        self.scalar = torch.nn.Parameter(torch.rand(()))

    def forward(self, inputs):
        pass


class TestPretrainedModelInitializer(AllenNlpTestCase):
    def setup_method(self):
        super().setup_method()
        self.net1 = _Net1()
        self.net2 = _Net2()
        self.temp_file = self.TEST_DIR / "weights.th"
        torch.save(self.net2.state_dict(), self.temp_file)

    def _are_equal(self, linear1: torch.nn.Linear, linear2: torch.nn.Linear) -> bool:
        return torch.equal(linear1.weight, linear2.weight) and torch.equal(
            linear1.bias, linear2.bias
        )

    def _get_applicator(
        self,
        regex: str,
        weights_file_path: str,
        parameter_name_overrides: Optional[Dict[str, str]] = None,
    ) -> InitializerApplicator:
        initializer = PretrainedModelInitializer(weights_file_path, parameter_name_overrides)
        return InitializerApplicator([(regex, initializer)])

    def test_random_initialization(self):
        # The other tests in this class rely on `self.net1` and `self.net2`
        # starting out with different, randomly initialized parameters.
        # This test makes sure that is actually the case.
        assert not self._are_equal(self.net1.linear_1, self.net2.linear_1)
        assert not self._are_equal(self.net1.linear_2, self.net2.linear_3)

    def test_from_params(self):
        params = Params({"type": "pretrained", "weights_file_path": self.temp_file})
        initializer = Initializer.from_params(params)
        assert initializer.weights
        assert initializer.parameter_name_overrides == {}

        name_overrides = {"a": "b", "c": "d"}
        params = Params(
            {
                "type": "pretrained",
                "weights_file_path": self.temp_file,
                "parameter_name_overrides": name_overrides,
            }
        )
        initializer = Initializer.from_params(params)
        assert initializer.weights
        assert initializer.parameter_name_overrides == name_overrides

    def test_from_params_tar_gz(self):
        with tempfile.NamedTemporaryFile(suffix=".tar.gz") as f:
            with tarfile.open(fileobj=f, mode="w:gz") as archive:
                archive.add(self.temp_file, arcname=os.path.basename(self.temp_file))
            # Make sure the archive is fully written to disk before it is read back.
            f.flush()
            params = Params({"type": "pretrained", "weights_file_path": f.name})
            initializer = Initializer.from_params(params)

        assert initializer.weights
        assert initializer.parameter_name_overrides == {}

        for name, parameter in self.net2.state_dict().items():
            assert torch.equal(parameter, initializer.weights[name])

    def test_default_parameter_names(self):
        # Initializes net1's `linear_1` from net2's weights file without any
        # parameter name overrides, so the parameter names must match exactly.
        applicator = self._get_applicator("linear_1.weight|linear_1.bias", self.temp_file)
        applicator(self.net1)
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1)
        assert not self._are_equal(self.net1.linear_2, self.net2.linear_3)

    def test_parameter_name_overrides(self):
        # Uses parameter name overrides to initialize all of net1's weights from
        # net2's, even though net1's `linear_2` is called `linear_3` in net2.
        name_overrides = {"linear_2.weight": "linear_3.weight", "linear_2.bias": "linear_3.bias"}
        applicator = self._get_applicator("linear_*", self.temp_file, name_overrides)
        applicator(self.net1)
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1)
        assert self._are_equal(self.net1.linear_2, self.net2.linear_3)

    def test_size_mismatch(self):
        # Verifies that a ConfigurationError is raised when a parameter is
        # initialized from a pretrained parameter with a different shape.
        name_overrides = {"linear_1.weight": "linear_3.weight"}
        applicator = self._get_applicator("linear_1.*", self.temp_file, name_overrides)
        with pytest.raises(ConfigurationError):
            applicator(self.net1)

    def test_zero_dim_tensor(self):
        # Verifies that a 0-dim tensor (the scalar parameter) can be initialized.
        applicator = self._get_applicator("scalar", self.temp_file)
        applicator(self.net1)
        assert torch.equal(self.net1.scalar, self.net2.scalar)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_load_to_gpu_from_gpu(self):
        # Makes sure the initializer works when both the model and the
        # pretrained weights are on the GPU.
        self.net1.cuda(device=0)
        self.net2.cuda(device=0)

        # Verify the parameters are on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True
        assert self.net2.linear_1.weight.is_cuda is True
        assert self.net2.linear_1.bias.is_cuda is True

        # setup_method() saved CPU weights, so save the GPU weights separately
        temp_file = self.TEST_DIR / "gpu_weights.th"
        torch.save(self.net2.state_dict(), temp_file)

        applicator = self._get_applicator("linear_1.*", temp_file)
        applicator(self.net1)

        # Verify the parameters are still on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True
        assert self.net2.linear_1.weight.is_cuda is True
        assert self.net2.linear_1.bias.is_cuda is True

        # Make sure the weights are identical
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_load_to_cpu_from_gpu(self):
        # Makes sure the initializer works when the model is on the CPU
        # but the weights file was saved from a model on the GPU.
        self.net2.cuda(device=0)

        # Verify the pretrained parameters are on the GPU
        assert self.net2.linear_1.weight.is_cuda is True
        assert self.net2.linear_1.bias.is_cuda is True

        temp_file = self.TEST_DIR / "gpu_weights.th"
        torch.save(self.net2.state_dict(), temp_file)

        applicator = self._get_applicator("linear_1.*", temp_file)
        applicator(self.net1)

        # Verify net1's parameters stayed on the CPU
        assert self.net1.linear_1.weight.is_cuda is False
        assert self.net1.linear_1.bias.is_cuda is False

        # Make sure the weights are identical
        assert self._are_equal(self.net1.linear_1, self.net2.linear_1.cpu())

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_load_to_gpu_from_cpu(self):
        # Makes sure the initializer works when the model is on the GPU
        # but the weights file was saved from a model on the CPU.
        self.net1.cuda(device=0)

        # Verify the parameters are on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True

        # setup_method() already saved the CPU weights, so reuse that file
        applicator = self._get_applicator("linear_1.*", self.temp_file)
        applicator(self.net1)

        # Verify the parameters are still on the GPU
        assert self.net1.linear_1.weight.is_cuda is True
        assert self.net1.linear_1.bias.is_cuda is True

        # Make sure the weights are identical
        assert self._are_equal(self.net1.linear_1.cpu(), self.net2.linear_1)