pytorch

Форк
0
/
test_vulkan.py 
164 строки · 6.7 Кб
1
# Owner(s): ["oncall: mobile"]
2

3
import unittest
4
import torch
5
from torch.nn import functional as F
6

7
from torch.testing._internal.common_utils import TestCase, run_tests
8
from torch.testing import FileCheck
9
import io
10

11
@unittest.skipUnless(torch.is_vulkan_available(),
12
                     "Vulkan backend must be available for these tests.")
13
class TestVulkanRewritePass(TestCase):
14
    @staticmethod
15
    def validate_transformed_module(
16
            # To please flake
17
            self,
18
            pattern_count_map,
19
            data_shape,
20
            prepack_removal=False,
21
            fuse_clamping_ops=False):
22
        module_instance = self
23
        scripted_model = torch.jit.script(module_instance)
24
        scripted_model.eval()
25
        input_data = torch.normal(1, 20, size=data_shape)
26
        ref_result = scripted_model(input_data)
27
        torch._C._jit_pass_vulkan_insert_prepacked_ops(scripted_model._c)
28
        if fuse_clamping_ops or prepack_removal:
29
            scripted_model._c = torch._C._freeze_module(scripted_model._c)
30
        if fuse_clamping_ops:
31
            torch._C._jit_pass_vulkan_fuse_clamp_w_prepacked_conv(scripted_model._c)
32
        if prepack_removal:
33
            torch._C._jit_pass_vulkan_fold_prepacking_ops(scripted_model._c)
34

35
        buffer = io.BytesIO()
36
        torch.jit.save(scripted_model, buffer)
37
        buffer.seek(0)
38
        deserialized_scripted_model = torch.jit.load(buffer)
39
        for pattern, v in pattern_count_map.items():
40
            if (v == 0):
41
                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
42
            elif (v == -1):
43
                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
44
            else:
45
                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
46

47
    def test_conv(self):
48
        # Conv params
49
        batch_size = 2
50
        input_channels_per_group = 6
51
        height = 16
52
        width = 16
53
        output_channels_per_group = 6
54
        groups = 4
55
        kernel_h = kernel_w = 3
56
        stride_h = stride_w = 1
57
        pad_h = pad_w = 1
58
        dilation = 1
59
        input_channels = input_channels_per_group * groups
60
        output_channels = output_channels_per_group * groups
61
        kernels = (kernel_h, kernel_w)
62
        strides = (stride_h, stride_w)
63
        paddings = (pad_h, pad_w)
64
        dilations = (dilation, dilation)
65
        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
66
        conv_bias_shape = (output_channels)
67

68
        class Conv2D(torch.nn.Module):
69
            def __init__(self) -> None:
70
                super().__init__()
71
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
72
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
73
                self.strides = strides
74
                self.paddings = paddings
75
                self.dilations = dilations
76
                self.groups = groups
77

78
            def forward(self, x):
79
                return F.conv2d(x, self.weight, self.bias,
80
                                self.strides, self.paddings, self.dilations, self.groups)
81

82
        data_shape = (batch_size, input_channels, height, width)
83
        pattern_count_map = {"Tensor = aten::conv2d": -1,
84
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
85
                             "vulkan_prepack::conv2d_clamp_run": 1}
86
        TestVulkanRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
87

88
        class Conv2DRelu(torch.nn.Module):
89
            def __init__(self) -> None:
90
                super().__init__()
91
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
92
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
93
                self.strides = strides
94
                self.paddings = paddings
95
                self.dilations = dilations
96
                self.groups = groups
97

98
            def forward(self, x):
99
                o = F.conv2d(x, self.weight, self.bias,
100
                             self.strides, self.paddings, self.dilations, self.groups)
101
                o = F.relu(o)
102
                return o
103

104
        data_shape = (batch_size, input_channels, height, width)
105
        pattern_count_map = {"Tensor = aten::conv2d": -1,
106
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
107
                             "vulkan_prepack::conv2d_clamp_run": 1}
108
        TestVulkanRewritePass.validate_transformed_module(
109
            Conv2DRelu(), pattern_count_map, data_shape)
110

111
        pattern_count_map["aten::relu"] = 1
112
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
113
        TestVulkanRewritePass.validate_transformed_module(
114
            Conv2DRelu(),
115
            pattern_count_map,
116
            data_shape,
117
            prepack_removal=True)
118
        pattern_count_map["aten::relu"] = -1
119
        TestVulkanRewritePass.validate_transformed_module(
120
            Conv2DRelu(),
121
            pattern_count_map,
122
            data_shape,
123
            prepack_removal=True,
124
            fuse_clamping_ops=True)
125

126

127
        class Conv2DHardtanh(torch.nn.Module):
128
            def __init__(self) -> None:
129
                super().__init__()
130
                self.weight = torch.nn.Parameter(torch.rand(conv_weight_shape), requires_grad=False)
131
                self.bias = torch.nn.Parameter(torch.rand(conv_bias_shape), requires_grad=False)
132
                self.strides = strides
133
                self.paddings = paddings
134
                self.dilations = dilations
135
                self.groups = groups
136

137
            def forward(self, x):
138
                o = F.conv2d(x, self.weight, self.bias,
139
                             self.strides, self.paddings, self.dilations, self.groups)
140
                o = F.hardtanh(o)
141
                return o
142

143
        data_shape = (batch_size, input_channels, height, width)
144
        pattern_count_map = {"Tensor = aten::conv2d": -1,
145
                             "vulkan_prepack::conv2d_clamp_prepack": 1,
146
                             "vulkan_prepack::conv2d_clamp_run": 1}
147
        TestVulkanRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
148
        pattern_count_map["aten::hardtanh"] = 1
149
        pattern_count_map["vulkan_prepack::conv2d_clamp_prepack"] = -1
150
        TestVulkanRewritePass.validate_transformed_module(
151
            Conv2DHardtanh(),
152
            pattern_count_map,
153
            data_shape,
154
            prepack_removal=True)
155
        pattern_count_map["aten::hardtanh"] = -1
156
        TestVulkanRewritePass.validate_transformed_module(
157
            Conv2DRelu(),
158
            pattern_count_map,
159
            data_shape,
160
            prepack_removal=True,
161
            fuse_clamping_ops=True)
162

163
if __name__ == "__main__":
164
    run_tests()
165

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.