allennlp

Форк
0
/
initializers_test.py 
132 строки · 4.8 Кб
1
import json
2
import logging
3
import math
4

5
import numpy
6
import pytest
7
import torch
8
import _jsonnet
9

10
from allennlp.nn import InitializerApplicator, Initializer
11
from allennlp.nn.initializers import block_orthogonal, uniform_unit_scaling
12
from allennlp.common.checks import ConfigurationError
13
from allennlp.common.testing import AllenNlpTestCase
14
from allennlp.common.params import Params
15

16

17
class TestInitializers(AllenNlpTestCase):
    """Tests for ``Initializer`` / ``InitializerApplicator``: construction from
    params, regex-based parameter matching, prevention regexes, and the
    ``block_orthogonal`` / ``uniform_unit_scaling`` initializer functions."""

    def setup_method(self):
        super().setup_method()
        # Re-enable initializer logging (another test may have disabled it).
        logging.getLogger("allennlp.nn.initializers").disabled = False

    def teardown_method(self):
        # NOTE(review): this was previously named `tearDown` (unittest style).
        # pytest only calls `tearDown` on unittest.TestCase subclasses, and this
        # class uses the pytest-style `setup_method` hook, so the old teardown
        # was never invoked and the logger stayed enabled. Renamed to the
        # matching pytest-style hook.
        super().teardown_method()
        logging.getLogger("allennlp.nn.initializers").disabled = True

    def test_from_params_string(self):
        # A bare string is shorthand for {"type": <string>}.
        Initializer.from_params(params="eye")

    def test_from_params_none(self):
        # None params should be accepted without raising.
        Initializer.from_params(params=None)

    def test_regex_matches_are_initialized_correctly(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                self.conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):
                pass

        # Make sure we handle regexes properly
        json_params = """{"initializer": {"regexes": [
        ["conv", {"type": "constant", "val": 5}],
        ["funky_na.*bi", {"type": "constant", "val": 7}]
        ]}}
        """
        params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
        initializers = InitializerApplicator.from_params(params=params["initializer"])
        model = Net()
        initializers(model)

        # "conv" matches both weight and bias of model.conv.
        for parameter in model.conv.parameters():
            assert torch.equal(parameter.data, torch.ones(parameter.size()) * 5)

        # "funky_na.*bi" matches only the bias of linear_1_with_funky_name.
        parameter = model.linear_1_with_funky_name.bias
        assert torch.equal(parameter.data, torch.ones(parameter.size()) * 7)

    def test_block_orthogonal_can_initialize(self):
        tensor = torch.zeros([10, 6])
        block_orthogonal(tensor, [5, 3])
        tensor = tensor.data.numpy()

        def test_block_is_orthogonal(block) -> None:
            # An orthogonal matrix B satisfies B^T B == I.
            matrix_product = block.T @ block
            numpy.testing.assert_array_almost_equal(
                matrix_product, numpy.eye(matrix_product.shape[-1]), 6
            )

        # Each of the four [5, 3] sub-blocks should be independently orthogonal.
        test_block_is_orthogonal(tensor[:5, :3])
        test_block_is_orthogonal(tensor[:5, 3:])
        test_block_is_orthogonal(tensor[5:, 3:])
        test_block_is_orthogonal(tensor[5:, :3])

    def test_block_orthogonal_raises_on_mismatching_dimensions(self):
        # Split sizes must evenly divide the tensor dimensions; [7, 2, 1]
        # does not divide [10, 6, 8].
        tensor = torch.zeros([10, 6, 8])
        with pytest.raises(ConfigurationError):
            block_orthogonal(tensor, [7, 2, 1])

    def test_uniform_unit_scaling_can_initialize(self):
        tensor = torch.zeros([10, 6])
        uniform_unit_scaling(tensor, "linear")

        # "linear" nonlinearity: values bounded by sqrt(3 / fan_in), fan_in=10.
        assert tensor.data.max() < math.sqrt(3 / 10)
        assert tensor.data.min() > -math.sqrt(3 / 10)

        # Check that it gets the scaling correct for relu (1.43).
        uniform_unit_scaling(tensor, "relu")
        assert tensor.data.max() < math.sqrt(3 / 10) * 1.43
        assert tensor.data.min() > -math.sqrt(3 / 10) * 1.43

    def test_regex_match_prevention_prevents_and_overrides(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_1 = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                # typical actual usage: modules loaded from allenlp.model.load(..)
                self.linear_3_transfer = torch.nn.Linear(5, 10)
                self.linear_4_transfer = torch.nn.Linear(10, 5)
                self.pretrained_conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):
                pass

        json_params = """{"initializer": {
        "regexes": [
            [".*linear.*", {"type": "constant", "val": 10}],
            [".*conv.*", {"type": "constant", "val": 10}]
            ],
        "prevent_regexes": [".*_transfer.*", ".*pretrained.*"]
        }}
        """
        params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
        initializers = InitializerApplicator.from_params(params=params["initializer"])
        model = Net()
        initializers(model)

        # Modules matched only by "regexes" are initialized to the constant.
        for module in [model.linear_1, model.linear_2]:
            for parameter in module.parameters():
                assert torch.equal(parameter.data, torch.ones(parameter.size()) * 10)

        # Modules also matched by "prevent_regexes" must be left untouched,
        # even though they match an initialization regex.
        transfered_modules = [
            model.linear_3_transfer,
            model.linear_4_transfer,
            model.pretrained_conv,
        ]

        for module in transfered_modules:
            for parameter in module.parameters():
                assert not torch.equal(parameter.data, torch.ones(parameter.size()) * 10)
133

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.