4
from torch.nn import Parameter
6
from allennlp.commands import main
7
from allennlp.common.testing import AllenNlpTestCase
10
def _clean_output(output: str) -> str:
11
# Removes color characters.
13
output.replace("\x1b[0m", "")
14
.replace("\x1b[31m", "")
15
.replace("\x1b[32m", "")
16
.replace("\x1b[33m", "")
21
class TestDiffCommand(AllenNlpTestCase):
22
def test_from_archive(self, capsys):
24
self.FIXTURES_ROOT / "basic_classifier" / "serialization" / "model.tar.gz"
26
sys.argv = ["allennlp", "diff", archive_path, archive_path]
28
captured = capsys.readouterr()
30
_clean_output(captured.out)
32
_text_field_embedder.token_embedder_tokens.weight, shape = (213, 10)
33
_seq2seq_encoder._module.weight_ih_l0, shape = (64, 10)
34
_seq2seq_encoder._module.weight_hh_l0, shape = (64, 16)
35
_seq2seq_encoder._module.bias_ih_l0, shape = (64,)
36
_seq2seq_encoder._module.bias_hh_l0, shape = (64,)
37
_feedforward._linear_layers.0.weight, shape = (20, 16)
38
_feedforward._linear_layers.0.bias, shape = (20,)
39
_classification_layer.weight, shape = (2, 20)
40
_classification_layer.bias, shape = (2,)
44
def test_from_huggingface(self, capsys):
45
model_id = "hf://epwalsh/bert-xsmall-dummy/pytorch_model.bin"
53
captured = capsys.readouterr()
55
_clean_output(captured.out)
57
embeddings.word_embeddings.weight, shape = (250, 20)
58
embeddings.position_embeddings.weight, shape = (512, 20)
59
embeddings.token_type_embeddings.weight, shape = (2, 20)
60
embeddings.LayerNorm.weight, shape = (20,)
61
embeddings.LayerNorm.bias, shape = (20,)
62
encoder.layer.0.attention.self.query.weight, shape = (20, 20)
63
encoder.layer.0.attention.self.query.bias, shape = (20,)
64
encoder.layer.0.attention.self.key.weight, shape = (20, 20)
65
encoder.layer.0.attention.self.key.bias, shape = (20,)
66
encoder.layer.0.attention.self.value.weight, shape = (20, 20)
67
encoder.layer.0.attention.self.value.bias, shape = (20,)
68
encoder.layer.0.attention.output.dense.weight, shape = (20, 20)
69
encoder.layer.0.attention.output.dense.bias, shape = (20,)
70
encoder.layer.0.attention.output.LayerNorm.weight, shape = (20,)
71
encoder.layer.0.attention.output.LayerNorm.bias, shape = (20,)
72
encoder.layer.0.intermediate.dense.weight, shape = (40, 20)
73
encoder.layer.0.intermediate.dense.bias, shape = (40,)
74
encoder.layer.0.output.dense.weight, shape = (20, 40)
75
encoder.layer.0.output.dense.bias, shape = (20,)
76
encoder.layer.0.output.LayerNorm.weight, shape = (20,)
77
encoder.layer.0.output.LayerNorm.bias, shape = (20,)
78
pooler.dense.weight, shape = (20, 20)
79
pooler.dense.bias, shape = (20,)
83
def test_diff_correct(self, capsys):
84
class ModelA(torch.nn.Module):
87
self.a = Parameter(torch.tensor([1.0, 0.0, 0.0]))
88
self.b = Parameter(torch.tensor([1.0, 0.0, 0.0]))
89
self.c = Parameter(torch.tensor([1.0, 0.0, 0.0]))
90
self.e = Parameter(torch.tensor([1.0, 0.0, 0.0]))
92
class ModelB(torch.nn.Module):
95
self.a = Parameter(torch.tensor([1.0, 0.0, 0.0]))
96
self.b = Parameter(torch.tensor([1.0, 0.0, 0.0, 0.0]))
97
self.d = Parameter(torch.tensor([1.0, 0.0, 0.0]))
98
self.e = Parameter(torch.tensor([1.0, 0.0, 1.0]))
103
torch.save(model_a.state_dict(), self.TEST_DIR / "checkpoint_a.pt")
104
torch.save(model_b.state_dict(), self.TEST_DIR / "checkpoint_b.pt")
108
str(self.TEST_DIR / "checkpoint_a.pt"),
109
str(self.TEST_DIR / "checkpoint_b.pt"),
112
captured = capsys.readouterr()
114
_clean_output(captured.out)
121
!e, shape = (3,), distance = 0.5774
124
# NOTE: the difference value here of for 'e' of 0.5774 is currently
125
# calculated at the square root of the mean squared difference between 'e'
126
# in 'model_a' and 'e' in 'model_b':
127
# sqrt( (0^2 + 0^2 + 1^2) / 3 ) = sqrt( 1/3 ) = 0.5774
129
# Now call again with a higher theshold.
133
str(self.TEST_DIR / "checkpoint_a.pt"),
134
str(self.TEST_DIR / "checkpoint_b.pt"),
139
captured = capsys.readouterr()
141
_clean_output(captured.out)
152
# And call a third time with the same threshold but a higher scale.
156
str(self.TEST_DIR / "checkpoint_a.pt"),
157
str(self.TEST_DIR / "checkpoint_b.pt"),
164
captured = capsys.readouterr()
166
_clean_output(captured.out)
173
!e, shape = (3,), distance = 5.7735