vision

Форк
0
/
test_backbone_utils.py 
324 строки · 13.1 Кб
1
import random
2
from itertools import chain
3
from typing import Mapping, Sequence
4

5
import pytest
6
import torch
7
from common_utils import set_rng_seed
8
from torchvision import models
9
from torchvision.models._utils import IntermediateLayerGetter
10
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
11
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
12

13

14
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
def test_resnet_fpn_backbone(backbone_name):
    """Build a ResNet+FPN backbone, run one forward pass and check argument validation."""
    images = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
    backbone = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
    assert isinstance(backbone, BackboneWithFPN)

    features = backbone(images)
    # Four numbered feature levels plus the extra "pool" level, in this order.
    assert list(features.keys()) == ["0", "1", "2", "3", "pool"]

    # Each entry: (extra kwargs that must be rejected, expected error message pattern).
    invalid_cases = (
        ({"trainable_layers": 6}, r"Trainable layers should be in the range"),
        ({"returned_layers": [0, 1, 2, 3]}, r"Each returned layer should be in the range"),
        ({"returned_layers": [2, 3, 4, 5]}, r"Each returned layer should be in the range"),
    )
    for bad_kwargs, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            resnet_fpn_backbone(backbone_name=backbone_name, weights=None, **bad_kwargs)
28

29

30
@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
def test_mobilenet_backbone(backbone_name):
    """Check MobileNet backbone argument validation and the type of the returned module."""
    # Each entry: (fpn flag, extra kwargs that must be rejected, expected error message pattern).
    invalid_cases = (
        (False, {"trainable_layers": -1}, r"Trainable layers should be in the range"),
        (True, {"returned_layers": [-1, 0, 1, 2]}, r"Each returned layer should be in the range"),
        (True, {"returned_layers": [3, 4, 5, 6]}, r"Each returned layer should be in the range"),
    )
    for use_fpn, bad_kwargs, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=use_fpn, **bad_kwargs)

    # fpn=True is expected to yield a BackboneWithFPN; fpn=False a plain nn.Sequential.
    with_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
    assert isinstance(with_fpn, BackboneWithFPN)
    without_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
    assert isinstance(without_fpn, torch.nn.Sequential)
42

43

44
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function, which hands
# this function to the FX tracer through ``autowrap_functions``.
def leaf_function(x):
    """Return ``x`` converted to a Python ``int``.

    ``int()`` cannot be applied to an FX proxy during symbolic tracing, so the
    test relies on the tracer wrapping this function rather than tracing into it.
    """
    as_int = int(x)
    return as_int
47

48

49
# Needed by TestFxFeatureExtraction. Checking that node naming conventions
# are respected, particularly the index postfix of repeated node names.
class TestSubModule(torch.nn.Module):
    """Submodule that applies ``+ 1`` twice and the same ReLU module twice.

    The repeated ops exercise FX node naming: the second occurrence of each op
    is expected to carry an ``_1`` suffix (see ``test_module_nodes``).
    """

    # The class name starts with "Test", so without this pytest would try to
    # collect it as a test class (and warn because it defines ``__init__``).
    __test__ = False

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x
62

63

64
class TestModule(torch.nn.Module):
    """Module that calls ``TestSubModule`` and then repeats its own ``+ 1`` / ReLU ops.

    Used by TestFxFeatureExtraction to check that node names of repeated ops
    are postfixed per scope (top level vs. ``submodule``).
    """

    # The class name starts with "Test", so without this pytest would try to
    # collect it as a test class (and warn because it defines ``__init__``).
    __test__ = False

    def __init__(self):
        super().__init__()
        self.submodule = TestSubModule()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.submodule(x)
        x = x + 1
        x = x + 1
        x = self.relu(x)
        x = self.relu(x)
        return x
77

78

79
# Expected node names reported by get_graph_node_names(TestModule()): the input
# placeholder "x", then the submodule's ops (prefixed with "submodule."), then
# the top-level ops. A repeated op within the same scope carries an "_1" suffix.
test_module_nodes = ["x"] + [
    scope + op for scope in ("submodule.", "") for op in ("add", "add_1", "relu", "relu_1")
]
90

91

92
class TestFxFeatureExtraction:
    """Tests for FX-based feature extraction via ``create_feature_extractor``."""

    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
    model_defaults = {"num_classes": 1}
    # Module types forwarded to the tracer as ``leaf_modules`` by default.
    leaf_modules = []

    def _create_feature_extractor(self, *args, **kwargs):
        """Call ``create_feature_extractor`` with this class's tracer defaults.

        If the caller does not pass ``tracer_kwargs``, ``self.leaf_modules`` is
        forwarded to the tracer. The train/eval graph-difference warning is
        always suppressed.
        """
        tracer_kwargs = kwargs.pop("tracer_kwargs", {"leaf_modules": self.leaf_modules})
        return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)

    def _get_return_nodes(self, model, num_nodes=10):
        """Sample train and eval node names of ``model`` after ``set_rng_seed(0)``.

        Args:
            model: the model whose graph node names are sampled.
            num_nodes: maximum number of names to sample per mode.

        Returns:
            A ``(train_nodes, eval_nodes)`` tuple of randomly sampled node names,
            each holding at most ``num_nodes`` entries.
        """
        set_rng_seed(0)
        exclude_nodes_filter = [
            "getitem",
            "floordiv",
            "size",
            "chunk",
            "_assert",
            "eq",
            "dim",
            "getattr",
        ]
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
        )
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass. Note: this is a substring match.
        train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
        # Cap the sample size so random.sample does not raise ValueError for
        # models with fewer than ``num_nodes`` candidate nodes.
        return (
            random.sample(train_nodes, min(num_nodes, len(train_nodes))),
            random.sample(eval_nodes, min(num_nodes, len(eval_nodes))),
        )

    @staticmethod
    def _aggregate_outputs(out):
        """Return the sum of the means of every tensor in a feature-extractor output.

        Each value of ``out`` may be a sequence of tensors, a mapping to tensors,
        or (assumed otherwise) a single tensor; ``None`` entries inside
        sequences/mappings are skipped.
        """
        out_agg = 0
        for node_out in out.values():
            if isinstance(node_out, Sequence):
                out_agg += sum(o.float().mean() for o in node_out if o is not None)
            elif isinstance(node_out, Mapping):
                out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
            else:
                # Assume that the only other alternative at this point is a Tensor
                out_agg += node_out.float().mean()
        return out_agg

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
        )
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check must specify return nodes
        with pytest.raises(ValueError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return nodes
        # mutual exclusivity
        with pytest.raises(ValueError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
            )
        # Check train_return_nodes / eval_return nodes must both be specified
        with pytest.raises(ValueError):
            self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError. Only run the check when no
        # sampled node is named "l" or nested under "l." (assumes return-node
        # queries match whole dotted path components -- confirm against
        # create_feature_extractor). A bare startswith("l") guard would also
        # skip unrelated names such as "layer1...".
        if not any(n == "l" or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
            with pytest.raises(ValueError):
                self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])

    def test_node_name_conventions(self):
        model = TestModule()
        train_nodes, _ = get_graph_node_names(model)
        # Compare whole lists: zip() would silently ignore missing or extra nodes.
        assert train_nodes == test_module_nodes

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        out = model(self.inp)
        self._aggregate_outputs(out).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}

        ilg_model = IntermediateLayerGetter(model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters (same count, names and values).
        # Materialize the lists so a length mismatch is not hidden by zip().
        ilg_params = list(ilg_model.named_parameters())
        fx_params = list(fx_model.named_parameters())
        assert len(ilg_params) == len(fx_params)
        for (n1, p1), (n2, p2) in zip(ilg_params, fx_params):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert list(ilg_out.keys()) == list(fgn_out.keys())
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize("model_name", models.list_models(models))
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.get_model(model_name, **self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        self._aggregate_outputs(fgn_out).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.0)

            def forward(self, x):
                x = x.float().mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ["dropout", "add", "sub"]
        eval_return_nodes = ["dropout", "mul", "sub"]

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == "train":
                # Check that dropout is respected
                assert out["dropout"].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 100
                assert "add" in out
                assert "mul" not in out
            elif mode == "eval":
                # Check that dropout is respected
                assert out["dropout"].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out["sub"].item() == 0
                assert "mul" in out
                assert "add" not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, "train")
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, "eval")

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
        )
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, "eval")
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, "train")

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(),
            return_nodes=["leaf_module"],
            tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
        ).train()

        # Check that LeafModule was not traced into: its inner relu must not
        # appear as a graph node, while the leaf_module call itself must.
        assert "relu" not in [str(n) for n in model.graph.nodes]
        assert "leaf_module" in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out["leaf_module"].float().mean().backward()
325

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.