2
from itertools import chain
3
from typing import Mapping, Sequence
7
from common_utils import set_rng_seed
8
from torchvision import models
9
from torchvision.models._utils import IntermediateLayerGetter
10
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
11
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
14
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
15
def test_resnet_fpn_backbone(backbone_name):
16
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
17
model = resnet_fpn_backbone(backbone_name=backbone_name, weights=None)
18
assert isinstance(model, BackboneWithFPN)
20
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
22
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
23
resnet_fpn_backbone(backbone_name=backbone_name, weights=None, trainable_layers=6)
24
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
25
resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[0, 1, 2, 3])
26
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
27
resnet_fpn_backbone(backbone_name=backbone_name, weights=None, returned_layers=[2, 3, 4, 5])
30
@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
31
def test_mobilenet_backbone(backbone_name):
32
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
33
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False, trainable_layers=-1)
34
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
35
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[-1, 0, 1, 2])
36
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
37
mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True, returned_layers=[3, 4, 5, 6])
38
model_fpn = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=True)
39
assert isinstance(model_fpn, BackboneWithFPN)
40
model = mobilenet_backbone(backbone_name=backbone_name, weights=None, fpn=False)
41
assert isinstance(model, torch.nn.Sequential)
51
class TestSubModule(torch.nn.Module):
54
self.relu = torch.nn.ReLU()
64
class TestModule(torch.nn.Module):
67
self.submodule = TestSubModule()
68
self.relu = torch.nn.ReLU()
92
class TestFxFeatureExtraction:
93
inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
94
model_defaults = {"num_classes": 1}
97
def _create_feature_extractor(self, *args, **kwargs):
102
if "tracer_kwargs" not in kwargs:
103
tracer_kwargs = {"leaf_modules": self.leaf_modules}
105
tracer_kwargs = kwargs.pop("tracer_kwargs")
106
return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True)
108
def _get_return_nodes(self, model):
110
exclude_nodes_filter = [
120
train_nodes, eval_nodes = get_graph_node_names(
121
model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True
125
train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)]
126
eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)]
127
return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
129
@pytest.mark.parametrize("model_name", models.list_models(models))
130
def test_build_fx_feature_extractor(self, model_name):
132
model = models.get_model(model_name, **self.model_defaults).eval()
133
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
135
self._create_feature_extractor(
136
model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes
138
self._create_feature_extractor(
139
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
142
with pytest.raises(ValueError):
143
self._create_feature_extractor(model)
146
with pytest.raises(ValueError):
147
self._create_feature_extractor(
148
model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes
151
with pytest.raises(ValueError):
152
self._create_feature_extractor(model, train_return_nodes=train_return_nodes)
154
with pytest.raises(ValueError):
156
if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)):
157
self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"])
161
def test_node_name_conventions(self):
163
train_nodes, _ = get_graph_node_names(model)
164
assert all(a == b for a, b in zip(train_nodes, test_module_nodes))
166
@pytest.mark.parametrize("model_name", models.list_models(models))
167
def test_forward_backward(self, model_name):
168
model = models.get_model(model_name, **self.model_defaults).train()
169
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
170
model = self._create_feature_extractor(
171
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
173
out = model(self.inp)
175
for node_out in out.values():
176
if isinstance(node_out, Sequence):
177
out_agg += sum(o.float().mean() for o in node_out if o is not None)
178
elif isinstance(node_out, Mapping):
179
out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
182
out_agg += node_out.float().mean()
185
def test_feature_extraction_methods_equivalence(self):
186
model = models.resnet18(**self.model_defaults).eval()
187
return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"}
189
ilg_model = IntermediateLayerGetter(model, return_layers).eval()
190
fx_model = self._create_feature_extractor(model, return_layers)
193
for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()):
198
with torch.no_grad():
199
ilg_out = ilg_model(self.inp)
200
fgn_out = fx_model(self.inp)
201
assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
202
for k in ilg_out.keys():
203
assert ilg_out[k].equal(fgn_out[k])
205
@pytest.mark.parametrize("model_name", models.list_models(models))
206
def test_jit_forward_backward(self, model_name):
208
model = models.get_model(model_name, **self.model_defaults).train()
209
train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
210
model = self._create_feature_extractor(
211
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
213
model = torch.jit.script(model)
214
fgn_out = model(self.inp)
216
for node_out in fgn_out.values():
217
if isinstance(node_out, Sequence):
218
out_agg += sum(o.float().mean() for o in node_out if o is not None)
219
elif isinstance(node_out, Mapping):
220
out_agg += sum(o.float().mean() for o in node_out.values() if o is not None)
223
out_agg += node_out.float().mean()
226
def test_train_eval(self):
227
class TestModel(torch.nn.Module):
230
self.dropout = torch.nn.Dropout(p=1.0)
232
def forward(self, x):
244
train_return_nodes = ["dropout", "add", "sub"]
245
eval_return_nodes = ["dropout", "mul", "sub"]
247
def checks(model, mode):
248
with torch.no_grad():
249
out = model(torch.ones(10, 10))
252
assert out["dropout"].item() == 0
254
assert out["sub"].item() == 100
256
assert "mul" not in out
259
assert out["dropout"].item() == 1
261
assert out["sub"].item() == 0
263
assert "add" not in out
267
fx_model = self._create_feature_extractor(
268
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
271
assert model.training
272
assert fx_model.training
274
checks(fx_model, "train")
277
checks(fx_model, "eval")
281
fx_model = self._create_feature_extractor(
282
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes
285
assert not model.training
286
assert not fx_model.training
288
checks(fx_model, "eval")
291
checks(fx_model, "train")
293
def test_leaf_module_and_function(self):
294
class LeafModule(torch.nn.Module):
295
def forward(self, x):
298
return torch.nn.functional.relu(x + 4)
300
class TestModule(torch.nn.Module):
303
self.conv = torch.nn.Conv2d(3, 1, 3)
304
self.leaf_module = LeafModule()
306
def forward(self, x):
307
leaf_function(x.shape[0])
309
return self.leaf_module(x)
311
model = self._create_feature_extractor(
313
return_nodes=["leaf_module"],
314
tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]},
318
assert "relu" not in [str(n) for n in model.graph.nodes]
319
assert "leaf_module" in [str(n) for n in model.graph.nodes]
322
out = model(self.inp)
324
out["leaf_module"].float().mean().backward()