allennlp

Форк
0
/
regularizers_test.py 
65 строк · 2.9 Кб
1
import re
2
import torch
3

4
from allennlp.common.params import Params
5
from allennlp.nn import InitializerApplicator, Initializer
6
from allennlp.nn.regularizers import L1Regularizer, L2Regularizer, RegularizerApplicator
7
from allennlp.common.testing import AllenNlpTestCase
8

9

10
class TestRegularizers(AllenNlpTestCase):
    """Exercise L1/L2 penalty values and regex-based regularizer application."""

    @staticmethod
    def _initialized_model(val):
        # Two stacked linear layers (5->10->5) with every parameter filled
        # with the constant ``val`` so penalty values are exactly predictable.
        net = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))
        fill = Initializer.from_params(Params({"type": "constant", "val": val}))
        InitializerApplicator([(".*", fill)])(net)
        return net

    def test_l1_regularization(self):
        net = self._initialized_model(-1)
        penalty = RegularizerApplicator([("", L1Regularizer(1.0))])(net)
        # 115 because of biases.
        assert penalty.data.numpy() == 115.0

    def test_l2_regularization(self):
        net = self._initialized_model(0.5)
        penalty = RegularizerApplicator([("", L2Regularizer(1.0))])(net)
        assert penalty.data.numpy() == 28.75

    def test_regularizer_applicator_respects_regex_matching(self):
        net = self._initialized_model(1.0)
        applicator = RegularizerApplicator(
            [("weight", L2Regularizer(0.5)), ("bias", L1Regularizer(1.0))]
        )
        assert applicator(net).data.numpy() == 65.0

    def test_from_params(self):
        params = Params({"regexes": [("conv", "l1"), ("linear", {"type": "l2", "alpha": 10})]})
        applicator = RegularizerApplicator.from_params(params)
        # Index the (regex, regularizer) pairs by their regex for lookup.
        by_regex = dict(applicator._regularizers)

        assert isinstance(by_regex.get("conv"), L1Regularizer)
        linear = by_regex.get("linear")
        assert isinstance(linear, L2Regularizer)
        assert linear.alpha == 10

    def test_frozen_params(self):
        net = self._initialized_model(-1)
        # Freeze every parameter of the first linear layer; frozen parameters
        # must not contribute to the penalty.
        for name, parameter in net.named_parameters():
            if re.search(r"0.*$", name):
                parameter.requires_grad = False
        penalty = RegularizerApplicator([("", L1Regularizer(1.0))])(net)
        # 55 because of bias (5*10 + 5)
        assert penalty.data.numpy() == 55
66

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.