10
from allennlp.nn import InitializerApplicator, Initializer
11
from allennlp.nn.initializers import block_orthogonal, uniform_unit_scaling
12
from allennlp.common.checks import ConfigurationError
13
from allennlp.common.testing import AllenNlpTestCase
14
from allennlp.common.params import Params
17
class TestInitializers(AllenNlpTestCase):
18
def setup_method(self):
    """Run the base-class setup, then un-disable the ``allennlp.nn.initializers`` logger."""
    super().setup_method()
    initializer_logger = logging.getLogger("allennlp.nn.initializers")
    initializer_logger.disabled = False
24
# Sets the `disabled` flag of the "allennlp.nn.initializers" logger to True.
# NOTE(review): the `def` line that owns this statement is not visible in this
# excerpt (presumably a teardown hook) — confirm against the full file.
logging.getLogger("allennlp.nn.initializers").disabled = True
26
def test_from_params_string(self):
    """Call ``Initializer.from_params`` with the bare string "eye"; passes if nothing raises."""
    initializer_spec = "eye"
    Initializer.from_params(params=initializer_spec)
29
def test_from_params_none(self):
    """Call ``Initializer.from_params`` with ``None``; passes if nothing raises."""
    initializer_spec = None
    Initializer.from_params(params=initializer_spec)
32
def test_regex_matches_are_initialized_correctly(self):
# Builds an InitializerApplicator from a JSON config (evaluated with _jsonnet,
# parsed with json.loads, wrapped in Params) that contains two regex entries:
# "conv" -> constant 5 and "funky_na.*bi" -> constant 7. It then asserts that
# every parameter of `model.conv` equals 5 and that
# `model.linear_1_with_funky_name.bias` equals 7.
# NOTE(review): the lines that close the JSON literal, create `model`, and apply
# `initializers` to it are not visible in this excerpt; presumably `model` is a
# Net instance — confirm against the full file.
33
# Small module holding two Linear layers and one Conv1d.
# NOTE(review): the constructor header that owns the attribute assignments
# below, and the body of `forward`, are not visible in this excerpt.
class Net(torch.nn.Module):
36
self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
37
self.linear_2 = torch.nn.Linear(10, 5)
38
self.conv = torch.nn.Conv1d(5, 5, 5)
40
def forward(self, inputs):
44
json_params = """{"initializer": {"regexes": [
45
["conv", {"type": "constant", "val": 5}],
46
["funky_na.*bi", {"type": "constant", "val": 7}]
49
params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
50
initializers = InitializerApplicator.from_params(params=params["initializer"])
54
for parameter in model.conv.parameters():
55
assert torch.equal(parameter.data, torch.ones(parameter.size()) * 5)
57
parameter = model.linear_1_with_funky_name.bias
58
assert torch.equal(parameter.data, torch.ones(parameter.size()) * 7)
60
def test_block_orthogonal_can_initialize(self):
    """Initialise a (10, 6) tensor with ``block_orthogonal`` using split sizes
    [5, 3] and check that each of the four (5, 3) blocks has orthonormal
    columns, i.e. ``block.T @ block`` is the identity to 6 decimal places.
    """
    weights = torch.zeros([10, 6])
    block_orthogonal(weights, [5, 3])
    weights = weights.data.numpy()

    def assert_columns_orthonormal(block) -> None:
        gram = block.T @ block
        identity = numpy.eye(gram.shape[-1])
        numpy.testing.assert_array_almost_equal(gram, identity, 6)

    top, bottom = slice(None, 5), slice(5, None)
    left, right = slice(None, 3), slice(3, None)
    # Blocks are visited in the same order as the original explicit calls.
    for rows, cols in ((top, left), (top, right), (bottom, right), (bottom, left)):
        assert_columns_orthonormal(weights[rows, cols])
76
def test_block_orthogonal_raises_on_mismatching_dimensions(self):
    """``block_orthogonal`` must raise ``ConfigurationError`` when given a
    (10, 6, 8) tensor together with split sizes [7, 2, 1].
    """
    weights = torch.zeros([10, 6, 8])
    split_sizes = [7, 2, 1]
    with pytest.raises(ConfigurationError):
        block_orthogonal(weights, split_sizes)
81
def test_uniform_unit_scaling_can_initialize(self):
    """Apply ``uniform_unit_scaling`` to a (10, 6) tensor and check the value range.

    With "linear" every value must lie strictly inside +/- sqrt(3 / 10); with
    "relu" the same bound is widened by a factor of 1.43.
    """
    weights = torch.zeros([10, 6])
    bound = math.sqrt(3 / 10)

    uniform_unit_scaling(weights, "linear")
    assert weights.data.max() < bound
    assert weights.data.min() > -bound

    uniform_unit_scaling(weights, "relu")
    assert weights.data.max() < bound * 1.43
    assert weights.data.min() > -bound * 1.43
93
def test_regex_match_prevention_prevents_and_overrides(self):
# Builds an InitializerApplicator from a JSON config (evaluated with _jsonnet,
# parsed with json.loads, wrapped in Params) whose visible entries are
# ".*linear.*" -> constant 10 and ".*conv.*" -> constant 10, plus
# "prevent_regexes" of ".*_transfer.*" and ".*pretrained.*". It then asserts
# that every parameter of `model.linear_1` and `model.linear_2` equals 10,
# while no parameter of `linear_3_transfer`, `linear_4_transfer` or
# `pretrained_conv` equals 10.
# NOTE(review): the key line introducing the regex list, the lines closing the
# JSON literal and the `transfered_modules` list, and the lines that create
# `model` and apply `initializers` to it are not visible in this excerpt;
# presumably `model` is a Net instance — confirm against the full file.
94
# Small module holding four Linear layers and one Conv1d; three of the
# attribute names contain "_transfer" or "pretrained".
# NOTE(review): the constructor header that owns the attribute assignments
# below, and the body of `forward`, are not visible in this excerpt.
class Net(torch.nn.Module):
97
self.linear_1 = torch.nn.Linear(5, 10)
98
self.linear_2 = torch.nn.Linear(10, 5)
100
self.linear_3_transfer = torch.nn.Linear(5, 10)
101
self.linear_4_transfer = torch.nn.Linear(10, 5)
102
self.pretrained_conv = torch.nn.Conv1d(5, 5, 5)
104
def forward(self, inputs):
107
json_params = """{"initializer": {
109
[".*linear.*", {"type": "constant", "val": 10}],
110
[".*conv.*", {"type": "constant", "val": 10}]
112
"prevent_regexes": [".*_transfer.*", ".*pretrained.*"]
115
params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
116
initializers = InitializerApplicator.from_params(params=params["initializer"])
120
for module in [model.linear_1, model.linear_2]:
121
for parameter in module.parameters():
122
assert torch.equal(parameter.data, torch.ones(parameter.size()) * 10)
124
transfered_modules = [
125
model.linear_3_transfer,
126
model.linear_4_transfer,
127
model.pretrained_conv,
130
for module in transfered_modules:
131
for parameter in module.parameters():
132
assert not torch.equal(parameter.data, torch.ones(parameter.size()) * 10)