def _get_test_image_tensor():
    """Load the bundled Grace Hopper test photo as a preprocessed image batch.

    The JPEG is read from the ``assets`` directory next to this file. It is
    resized so its shorter side is 256 px, center-cropped to 224x224,
    converted to a float tensor and normalized per channel.

    Returns:
        A tensor of shape ``(1, C, 224, 224)``: the single preprocessed image
        with a leading batch dimension added by ``unsqueeze(0)``. ``C`` is
        presumably 3, since ``Normalize`` below has three per-channel entries
        and so assumes an RGB image.
    """
    data_dir = os.path.join(os.path.dirname(__file__), "assets")
    img_path = os.path.join(data_dir, "grace_hopper_517x606.jpg")
    # The mean/std values are the usual ImageNet statistics expected by
    # torchvision's pretrained classifiers.
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    # ``Image.open`` keeps the underlying file open until the image is closed.
    # The ``with`` block releases the handle once the tensor has been built.
    # The returned tensor does not reference the file.
    with PIL.Image.open(img_path) as input_image:
        return preprocess(input_image).unsqueeze(0)


# Wraps a classifier so only its top-1 prediction is produced and compared.
class _TopPredictor(nn.Module):
    """Reduce a classifier's output to the index of its highest-scoring class.

    Args:
        base_model: module whose output ``x`` is indexed as ``x[0]``. It is
            assumed to return logits shaped ``(batch, num_classes)``, and only
            the first batch entry is used. TODO confirm for all wrapped models.
    """

    def __init__(self, base_model):
        # nn.Module.__init__ must run before a submodule can be assigned as an
        # attribute; otherwise the assignment raises.
        super().__init__()
        self.base_model = base_model

    def forward(self, x):
        """Run ``base_model`` on ``x`` and return a 1-element tensor holding the top-1 index."""
        x = self.base_model(x)
        _, topk_id = torch.topk(x[0], 1)
        return topk_id


@parameterized.parameterized_class(
    # NOTE(review): the attribute names/values of this decorator were missing
    # from the reviewed copy. Parameterizing over the ``is_script`` flag of
    # ``_TestONNXRuntime`` (True/False) is an assumption; confirm against
    # onnx_test_common before merging.
    ("is_script",),
    ([True, False],),
    class_name_func=onnx_test_common.parameterize_class_name,
)
class TestQuantizedModelsONNXRuntime(onnx_test_common._TestONNXRuntime):
    """ONNX Runtime checks for torchvision's pretrained quantized classifiers.

    ``run_test`` wraps every model in ``_TopPredictor``. The comparison is
    therefore on the top-1 class index for the bundled test image rather than
    on raw outputs (presumably because quantization makes exact logits
    unstable across backends).
    """

    def run_test(self, model, inputs, *args, **kwargs):
        """Wrap ``model`` in ``_TopPredictor`` and defer to the base-class ``run_test``."""
        model = _TopPredictor(model)
        return super().run_test(model, inputs, *args, **kwargs)

    def _check_pretrained_quantized(self, builder):
        """Build ``builder(pretrained=True, quantize=True)`` and test it on the sample image.

        NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        favor of ``weights=``. It is kept unchanged here.
        """
        model = builder(pretrained=True, quantize=True)
        self.run_test(model, _get_test_image_tensor())

    def test_mobilenet_v3(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.mobilenet_v3_large
        )

    @unittest.skip("quantized::cat not supported")
    def test_inception_v3(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.inception_v3
        )

    @unittest.skip("quantized::cat not supported")
    def test_googlenet(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.googlenet
        )

    @unittest.skip("quantized::cat not supported")
    def test_shufflenet_v2_x0_5(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.shufflenet_v2_x0_5
        )

    def test_resnet18(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.resnet18
        )

    def test_resnet50(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.resnet50
        )

    def test_resnext101_32x8d(self):
        self._check_pretrained_quantized(
            torchvision.models.quantization.resnext101_32x8d
        )