import re

import torch

from allennlp.common.params import Params
from allennlp.nn import InitializerApplicator, Initializer
from allennlp.nn.regularizers import L1Regularizer, L2Regularizer, RegularizerApplicator
from allennlp.common.testing import AllenNlpTestCase


class TestRegularizers(AllenNlpTestCase):
    """Tests for the L1/L2 regularizers and ``RegularizerApplicator``.

    Most tests use the same two-layer model,
    ``Sequential(Linear(5, 10), Linear(10, 5))``.  Its parameter counts are:

    * layer 0: 5 * 10 weights + 10 biases = 60
    * layer 1: 10 * 5 weights + 5 biases = 55

    for a total of 115 parameters (100 weights, 15 biases).  Every parameter
    is set to a known constant before the penalty is computed, so the
    expected values below can be derived by hand.
    """

    def test_l1_regularization(self):
        """L1 over all 115 parameters, each equal to -1, sums to 115."""
        model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))
        constant_init = Initializer.from_params(Params({"type": "constant", "val": -1}))
        initializer = InitializerApplicator([(".*", constant_init)])
        # The applicator must actually be run on the model; otherwise the
        # parameters keep their random default values.
        initializer(model)

        # The empty regex matches every parameter name.
        value = RegularizerApplicator([("", L1Regularizer(1.0))])(model)

        # 115 rather than 100 because the biases are penalised too.
        assert value.data.numpy() == 115.0

    def test_l2_regularization(self):
        """L2 over all 115 parameters, each equal to 0.5: 115 * 0.25 = 28.75."""
        model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))
        constant_init = Initializer.from_params(Params({"type": "constant", "val": 0.5}))
        initializer = InitializerApplicator([(".*", constant_init)])
        initializer(model)

        value = RegularizerApplicator([("", L2Regularizer(1.0))])(model)

        assert value.data.numpy() == 28.75

    def test_regularizer_applicator_respects_regex_matching(self):
        """Each regex should route its regularizer to the matching parameters.

        With every parameter equal to 1.0:

        * weights get L2 with alpha 0.5: 0.5 * 100 = 50
        * biases get L1 with alpha 1.0: 1.0 * 15 = 15

        for a total of 65.
        """
        model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))
        constant_init = Initializer.from_params(Params({"type": "constant", "val": 1.0}))
        initializer = InitializerApplicator([(".*", constant_init)])
        initializer(model)

        value = RegularizerApplicator(
            [("weight", L2Regularizer(0.5)), ("bias", L1Regularizer(1.0))]
        )(model)

        assert value.data.numpy() == 65.0

    def test_from_params(self):
        """``from_params`` should build the right regularizer type per regex."""
        params = Params({"regexes": [("conv", "l1"), ("linear", {"type": "l2", "alpha": 10})]})
        regularizer_applicator = RegularizerApplicator.from_params(params)
        regularizers = regularizer_applicator._regularizers

        # Pre-bind both names so that a missing regex fails the isinstance
        # assertions below instead of raising NameError.
        conv = linear = None
        for regex, regularizer in regularizers:
            if regex == "conv":
                conv = regularizer
            elif regex == "linear":
                linear = regularizer

        assert isinstance(conv, L1Regularizer)
        assert isinstance(linear, L2Regularizer)
        assert linear.alpha == 10

    def test_frozen_params(self):
        """Parameters with ``requires_grad = False`` should not be penalised."""
        model = torch.nn.Sequential(torch.nn.Linear(5, 10), torch.nn.Linear(10, 5))
        constant_init = Initializer.from_params(Params({"type": "constant", "val": -1}))
        initializer = InitializerApplicator([(".*", constant_init)])
        initializer(model)

        # Freeze the first linear layer; its parameters are named
        # "0.weight" and "0.bias" inside the Sequential container.
        for name, param in model.named_parameters():
            if re.search(r"0.*$", name):
                param.requires_grad = False

        value = RegularizerApplicator([("", L1Regularizer(1.0))])(model)

        # Only the second layer counts: 10 * 5 weights + 5 biases = 55.
        assert value.data.numpy() == 55