# Owner(s): ["module: fx"]

import unittest
import torch
import torch.fx

from torch.testing._internal.common_utils import TestCase
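
# These tests exercise torch.fx.Tracer(param_shapes_constant=True): with the
# flag set, shape queries on module parameters (shape, size(), dim(), ndim,
# numel(), nelement()) evaluate to concrete Python values during tracing, so
# control flow that branches on them is resolved at trace time and baked into
# the generated graph.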


class MyModuleBase(torch.nn.Module):
    def forward(self, x):
        matrx = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrx)
        else:
            return torch.relu(torch.mm(x, matrx))

    def get_mul_matrix(self):
        # Default multiplication matrix; subclasses may override this.
        return self.param

    def no_relu(self):
        raise Exception("not implemented")
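
# Each subclass below overrides no_relu() with a different shape query on
# self.param; under param_shapes_constant=True the tracer evaluates the query
# to a plain bool, so forward() follows exactly one branch during tracing.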
class MyModuleParamShape(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.shape[0] < 10


class MyModuleParamSize(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.size()[0] < 10


class MyModuleParamDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.dim() == 3) else self.param

    def no_relu(self):
        return self.param.dim() == 3


class MyModuleParamNDim(MyModuleBase):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        return self.param[0] if (self.param.ndim == 3) else self.param

    def no_relu(self):
        return self.param.ndim == 3


class MyModuleParamNumEl(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.numel() < 10 * 3
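

# Tensor.nelement() is an alias of Tensor.numel(); the module below checks
# that the alias is constant-folded the same way.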
class MyModuleParamNElement(MyModuleBase):
    def __init__(self, in_channels):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(in_channels, 3))

    def no_relu(self):
        return self.param.nelement() < 10 * 3


class TestConstParamShapeInControlFlow(TestCase):
    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify that one module does only an mm op while the other
        performs both mm and relu ops in cascade.
        """
        x = torch.randn(10, 5)
        torch.testing.assert_close(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix()))
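
        # param_shapes_constant=True makes the tracer record parameter shape
        # accesses as concrete values rather than proxy nodes, so no_relu()
        # returns a real bool and only one branch of forward() is traced.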
        tracer = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph = tracer.trace(mm_only_mod)

        # verify the graph module calculates the same result
        graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
        torch.testing.assert_close(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix()))

        # Make a new module with a different parameter shape to go down the
        # other code path
        x = torch.randn(10, 15)
        torch.testing.assert_close(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))

        tracer2 = torch.fx.Tracer(param_shapes_constant=True)
        traced_graph2 = tracer2.trace(relu_mod)

        # verify the graph module calculates the same result
        graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
        torch.testing.assert_close(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))

        graph1_node_targets = [n.target for n in traced_graph.nodes]
        graph2_node_targets = [n.target for n in traced_graph2.nodes]

        # the second graph has an extra relu function call node
        assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
        assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
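
        # Without param_shapes_constant=True, the shape query would return a
        # Proxy, and branching on it in forward() would raise a TraceError
        # instead of producing either graph.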

    def test_param_shape_const(self):
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)


if __name__ == '__main__':
    unittest.main()