6
from torch._custom_op import impl as custom_op
7
from torch.onnx._internal.fx.passes import _utils as pass_utils
8
from torch.testing._internal import common_utils
11
class TestFxPasses(common_utils.TestCase):
12
def test_set_node_name_correctly_renames_when_new_name_collides_recursively(self):
19
gm, _ = torch._dynamo.export(func)(x, y, z)
25
nodes = list(gm.graph.nodes)
26
for i, node in enumerate(nodes[1:]):
30
node.name = f"{base_name}.{i}"
33
name_to_node = {node.name: node for node in gm.graph.nodes}
34
pass_utils.set_node_name(nodes[0], base_name, name_to_node)
35
assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}"
36
assert len({node.name for node in nodes}) == len(
38
), f"Expected all names to be unique, got {nodes}"
40
def test_set_node_name_succeeds_when_no_name_collisions(self):
47
gm, _ = torch._dynamo.export(func)(x, y, z)
51
new_name = "some_tensor"
52
nodes = list(gm.graph.nodes)
53
name_to_node = {node.name: node for node in nodes}
54
pass_utils.set_node_name(nodes[1], new_name, name_to_node)
55
assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}"
56
assert len({node.name for node in nodes}) == len(
58
), f"Expected all names to be unique, got {nodes}"
60
def test_onnx_dynamo_export_raises_when_model_contains_unsupported_fx_nodes(self):
61
@custom_op.custom_op("mylibrary::foo_op")
62
def foo_op(x: torch.Tensor) -> torch.Tensor:
65
@custom_op.custom_op("mylibrary::bar_op")
66
def bar_op(x: torch.Tensor) -> torch.Tensor:
69
@foo_op.impl_abstract()
70
def foo_op_impl_abstract(x):
71
return torch.empty_like(x)
77
@bar_op.impl_abstract()
78
def bar_op_impl_abstract(x):
79
return torch.empty_like(x)
85
torch._dynamo.allow_in_graph(foo_op)
86
torch._dynamo.allow_in_graph(bar_op)
89
return foo_op(x) + bar_op(y) + z
94
with self.assertRaises(torch.onnx.OnnxExporterError) as ctx:
95
torch.onnx.dynamo_export(func, x, y, z)
96
inner_exception = ctx.exception.__cause__
99
r"Unsupported FX nodes.*mylibrary\.foo_op.*mylibrary\.bar_op",
102
torch._dynamo.reset()
105
@common_utils.instantiate_parametrized_tests
106
class TestModularizePass(common_utils.TestCase):
107
@common_utils.parametrize(
108
"is_exported_program",
110
common_utils.subtest(
112
name="exported_program",
114
common_utils.subtest(
120
def test_modularize_pass_succeeds_when_submodule_output_is_unused(
121
self, is_exported_program
131
class TestModule(torch.nn.Module):
134
self.unused_relu = torch.nn.ReLU()
135
self.used_gelu = torch.nn.GELU()
137
def forward(self, x, y):
138
result = self.used_gelu(x + y)
139
unused_relu_result = self.unused_relu(x)
142
if is_exported_program:
143
model = torch.export.export(
144
TestModule(), args=(torch.randn(3), torch.randn(3))
149
onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
150
model_proto = onnx_program.model_proto
151
function_proto_names = [function.name for function in model_proto.functions]
153
"torch_nn_modules_activation_GELU_used_gelu_1", function_proto_names
155
self.assertFalse(any("ReLU" in name for name in function_proto_names))
157
@common_utils.parametrize(
158
"is_exported_program",
160
common_utils.subtest(
162
name="exported_program",
164
common_utils.subtest(
170
def test_modularize_pass_succeeds_when_a_submodule_is_called_multiple_times(
171
self, is_exported_program
173
class TestModule(torch.nn.Module):
176
self.relu = torch.nn.ReLU()
178
def forward(self, x, y):
185
if is_exported_program:
186
model = torch.export.export(
187
TestModule(), args=(torch.randn(3), torch.randn(3))
192
onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
193
model_proto = onnx_program.model_proto
194
function_proto_names = [function.name for function in model_proto.functions]
195
self.assertIn("torch_nn_modules_activation_ReLU_relu_1", function_proto_names)
196
self.assertIn("torch_nn_modules_activation_ReLU_relu_2", function_proto_names)
198
@common_utils.parametrize(
199
"is_exported_program",
201
common_utils.subtest(
203
name="exported_program",
205
common_utils.subtest(
211
def test_modularize_pass_succeeds_when_a_submodule_is_called_from_multiple_layers(
212
self, is_exported_program
215
class InnerModule(torch.nn.Module):
218
self.relu = torch.nn.ReLU()
220
def forward(self, x):
223
class TestModule(torch.nn.Module):
226
self.inner_module = InnerModule()
228
def forward(self, x, y):
230
out = self.inner_module(out)
232
out = self.inner_module.relu(out)
235
if is_exported_program:
236
model = torch.export.export(
237
TestModule(), args=(torch.randn(3), torch.randn(3))
242
onnx_program = torch.onnx.dynamo_export(model, torch.randn(3), torch.randn(3))
243
model_proto = onnx_program.model_proto
244
function_proto_names = [function.name for function in model_proto.functions]
246
"torch_nn_modules_activation_ReLU_inner_module_relu_1", function_proto_names
249
"torch_nn_modules_activation_ReLU_inner_module_relu_2", function_proto_names
254
any("InnerModule_inner_module_1" in name for name in function_proto_names)
258
if __name__ == "__main__":
259
common_utils.run_tests()