# Owner(s): ["oncall: mobile"]

import io
import unittest

import torch
from torch.nn import functional as F
from torch.testing import FileCheck
from torch.testing._internal.common_utils import TestCase, run_tests
@unittest.skipUnless(torch.is_vulkan_available(),
                     "Vulkan backend must be available for these tests.")
class TestVulkanRewritePass(TestCase):
    """Tests for the JIT rewrite passes that prepare a scripted model for the
    Vulkan backend: prepacked-op insertion, clamp fusion, and prepack folding.
    """

    @staticmethod
    def validate_transformed_module(
            # NOTE: `self` is the eager module under test, not the TestCase;
            # the helper is always invoked through the class.
            self,
            pattern_count_map,
            data_shape,
            prepack_removal=False,
            fuse_clamping_ops=False):
        """Script `self`, run the Vulkan rewrite passes, round-trip the model
        through serialization, and FileCheck the deserialized graph.

        Args:
            self: eager ``torch.nn.Module`` to transform.
            pattern_count_map: dict mapping a graph-pattern string to its
                expected count: ``0`` -> pattern must be present, ``-1`` ->
                pattern must be absent, any other value -> exact count.
            data_shape: shape of the random input used for a smoke forward run.
            prepack_removal: if True, freeze the module and fold the
                prepacking ops away.
            fuse_clamping_ops: if True, freeze the module and fuse clamp ops
                into the prepacked convolution.
        """
        module_instance = self
        scripted_model = torch.jit.script(module_instance)
        # Smoke-test the scripted model before transforming it.
        input_data = torch.normal(1, 20, size=data_shape)
        ref_result = scripted_model(input_data)  # noqa: F841
        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
        # Both optional passes operate on a frozen module.
        if fuse_clamping_ops or prepack_removal:
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
        if fuse_clamping_ops:
            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
        if prepack_removal:
            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)

        # Serialize and reload so the checks run against the graph a user
        # would actually load.
        buffer = io.BytesIO()
        torch.jit.save(scripted_model, buffer)
        buffer.seek(0)
        deserialized_scripted_model = torch.jit.load(buffer)
        for pattern, v in pattern_count_map.items():
            if v == 0:
                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
            elif v == -1:
                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
            else:
                FileCheck().check_count(pattern, v, exactly=True).run(
                    deserialized_scripted_model.graph)

    def test_conv(self):
        """Conv2d alone, Conv2d+ReLU and Conv2d+Hardtanh must all be rewritten
        to ``vulkan_prepack`` ops; the clamping variants must additionally be
        fused when ``fuse_clamping_ops`` is requested."""
        # Conv params.  NOTE(review): some concrete sizes were lost in the
        # mangled source and reconstructed here; any setting consistent with
        # input_channels = input_channels_per_group * groups exercises the
        # same rewrite patterns.
        batch_size = 2
        input_channels_per_group = 6
        height = 16
        width = 16
        output_channels_per_group = 6
        groups = 1
        kernel_h = kernel_w = 3
        stride_h = stride_w = 1
        pad_h = pad_w = 1
        dilation = 1
        input_channels = input_channels_per_group * groups
        output_channels = output_channels_per_group * groups
        strides = (stride_h, stride_w)
        paddings = (pad_h, pad_w)
        dilations = (dilation, dilation)
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
        conv_bias_shape = (output_channels,)

        class Conv2D(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                return F.conv2d(x, self.weight, self.bias,
                                self.strides, self.paddings, self.dilations, self.groups)

        data_shape = (batch_size, input_channels, height, width)
        # Plain conv: aten::conv2d must disappear in favor of one
        # prepack/run pair.
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)

        class Conv2DRelu(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.relu(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(), pattern_count_map, data_shape)

        # After freezing + prepack folding, the standalone relu survives but
        # the prepack op itself is folded away.
        pattern_count_map["aten::relu"] = 1
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        # With clamp fusion the relu is absorbed into the prepacked conv.
        pattern_count_map["aten::relu"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DRelu(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)

        class Conv2DHardtanh(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
                self.strides = strides
                self.paddings = paddings
                self.dilations = dilations
                self.groups = groups

            def forward(self, x):
                o = F.conv2d(x, self.weight, self.bias,
                             self.strides, self.paddings, self.dilations, self.groups)
                o = F.hardtanh(o)
                return o

        data_shape = (batch_size, input_channels, height, width)
        pattern_count_map = {"Tensor = aten::conv2d": -1,
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
                             "vulkan_prepack::conv2d_clamp_run": 1}
        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
        # Same progression as the relu variant, with aten::hardtanh.
        pattern_count_map["aten::hardtanh"] = 1
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True)
        pattern_count_map["aten::hardtanh"] = -1
        TestVulkanRewritePass.validate_transformed_module(
            Conv2DHardtanh(),
            pattern_count_map,
            data_shape,
            prepack_removal=True,
            fuse_clamping_ops=True)
if __name__ == "__main__":
    # Standard PyTorch test entry point (imported from common_utils above).
    run_tests()