pytorch

Форк
0
/
test_fx_to_onnx_decomp_skip.py 
53 строки · 2.0 Кб
1
# Owner(s): ["module: onnx"]
2
from __future__ import annotations
3

4
import onnx
5
import onnx.inliner
6
import pytorch_test_common
7

8
import torch
9
from torch.testing._internal import common_utils
10

11

12
def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
13
    inlined = onnx.inliner.inline_local_functions(model)
14
    for node in inlined.graph.node:
15
        if node.op_type == op_type:
16
            return
17
    raise AssertionError(f"Op {op_type} not found in model")
18

19

20
class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
21
    def test_upsample_bilinear2d(self):
22
        class TestModel(torch.nn.Module):
23
            def __init__(self):
24
                super().__init__()
25
                self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")
26

27
            def forward(self, x):
28
                return self.upsample(x)
29

30
        onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
31
        # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
32
        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
33

34
    def test_upsample_bilinear2d_output_size(self):
35
        def func(x: torch.Tensor):
36
            return torch.nn.functional.interpolate(x, size=(4, 4), mode="bilinear")
37

38
        onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
39
        # If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
40
        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
41

42
    def test_instance_norm(self):
43
        def func(x: torch.Tensor):
44
            return torch.nn.functional.instance_norm(x)
45

46
        onnx_program = torch.onnx.dynamo_export(func, torch.randn(1, 1, 2, 2))
47
        # If decomposition is skipped, the model will contain an InstanceNormalization op
48
        # instead of BatchNormalization op w/ training=True.
49
        assert_op_in_onnx_model(onnx_program.model_proto, "InstanceNormalization")
50

51

52
if __name__ == "__main__":
53
    common_utils.run_tests()
54

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.