test_fx_param_shape_control_flow.py
# Owner(s): ["module: fx"]
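
# These tests check torch.fx tracing of modules whose Python control flow
# branches on parameter shapes: with Tracer(param_shapes_constant=True), shape
# accessors on module parameters evaluate to plain Python values at trace time,
# so each branch condition is resolved to a constant and only the taken branch
# is recorded in the graph.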

import unittest

import torch
import torch.fx
from torch.testing._internal.common_utils import TestCase


class MyModuleBase(torch.nn.Module):
    def forward(self, x):
        matrx = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrx)
        else:
            return torch.relu(torch.mm(x, matrx))

    def get_mul_matrix(self):
        return self.param

    def no_relu(self):
        raise Exception("not implemented")  # noqa: TRY002
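

# Each subclass below overrides no_relu() (and sometimes get_mul_matrix()) to
# branch on a different constant-shape accessor: shape, size(), dim(), ndim,
# numel(), and nelement(). param_shapes_constant=True lets the tracer evaluate
# each of these to a Python constant while tracing.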
class MyModuleParamShape(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.shape[0] < 10


class MyModuleParamSize(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.size()[0] < 10


class MyModuleParamDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.dim() == 3) else self.param

    def no_relu(self):
        return self.param.dim() == 3


class MyModuleParamNDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.ndim == 3) else self.param

    def no_relu(self):
        return self.param.ndim == 3


class MyModuleParamNumEl(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.numel() < 10 * 3


class MyModuleParamNElement(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.nelement() < 10 * 3


class TestConstParamShapeInControlFlow(TestCase):
    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify that one module performs only an mm op while the other
        performs mm followed by relu.
        """
        x = torch.randn(10, 5)
        torch.testing.assert_close(
            mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())
        )
        tracer = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph = tracer.trace(mm_only_mod)

        # verify the graph module calculates the same result
        graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
        torch.testing.assert_close(
            graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())
        )

        # The second module has a different parameter shape and goes down the
        # other code path
        x = torch.randn(10, 15)
        torch.testing.assert_close(
            relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
        )

        tracer2 = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph2 = tracer2.trace(relu_mod)

        # verify the graph module calculates the same result
        graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
        torch.testing.assert_close(
            graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
        )

        graph1_node_targets = [n.target for n in traced_graph.nodes]
        graph2_node_targets = [n.target for n in traced_graph2.nodes]

        # the second graph has an extra relu function call node
        assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
        assert (
            torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
        )

    def test_param_shape_const(self):
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)
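

# Illustrative sketch, not exercised by the test runner (the helper name
# _demo_trace is ours, not part of the test suite): with
# param_shapes_constant=True the shape check in no_relu() runs at trace time,
# so only the taken branch appears in the traced graph.
def _demo_trace():
    mod = MyModuleParamShape(in_channels=5)  # shape[0] < 10, so no_relu() is True
    tracer = torch.fx.Tracer(param_shapes_constant=True)
    graph = tracer.trace(mod)
    targets = [n.target for n in graph.nodes]
    # Only the mm branch was recorded; relu was folded away with the condition.
    assert torch.mm in targets and torch.relu not in targets
    print(graph)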


if __name__ == "__main__":
    unittest.main()