pytorch

Форк
0
/
test_onnxscript_runtime.py 
129 строк · 4.3 Кб
1
# Owner(s): ["module: onnx"]
2

3
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
4

5
from typing import List
6

7
import onnx_test_common
8
import onnxscript
9
from onnxscript.onnx_types import FLOAT
10

11
import torch
12
from torch.onnx._internal import jit_utils
13
from torch.testing._internal import common_utils
14

15

16
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
17
    # opset version is
18
    # 1. local function is supported after opset 15
19
    # 2. onnx-script requires users to determine opset in local function
20
    opset_version = 15
21

22
    def test_selu_from_onnxscript_example(self):
23
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
24
        model = torch.nn.SELU()
25

26
        from onnxscript.onnx_opset import opset15 as op
27

28
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
29

30
        @onnxscript.script(custom_opset)
31
        def Selu(
32
            X,
33
        ):
34
            # default value is not supported by onnxscript
35
            alpha = 1.67326  # auto wrapped as Constants
36
            gamma = 1.0507
37
            alphaX = op.CastLike(alpha, X)
38
            gammaX = op.CastLike(gamma, X)
39
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
40
            pos = gammaX * X
41
            zero = op.CastLike(0, X)
42
            return op.Where(X <= zero, neg, pos)
43

44
        def custom_selu(g: jit_utils.GraphContext, X):
45
            return g.onnxscript_op(Selu, X).setType(X.type())
46

47
        torch.onnx.register_custom_op_symbolic(
48
            symbolic_name="aten::selu",
49
            symbolic_fn=custom_selu,
50
            opset_version=self.opset_version,
51
        )
52
        self.run_test(model, x)
53

54
    def test_layer_norm(self):
55
        x = torch.randn(2, 3)
56
        y = torch.randn(2, 3)
57
        z = torch.randn(2, 3)
58

59
        class N(torch.nn.Module):
60
            def __init__(self, prob):
61
                super().__init__()
62
                self.dropout = torch.nn.Dropout(prob)
63

64
            def forward(self, x):
65
                return self.dropout(x)
66

67
        class M(torch.nn.Module):
68
            def __init__(self, num_layers):
69
                super().__init__()
70
                self.num_layers = num_layers
71
                self.lns = torch.nn.ModuleList(
72
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
73
                )
74
                self.celu1 = torch.nn.CELU(1.0)
75
                self.celu2 = torch.nn.CELU(2.0)
76
                self.dropout = N(0.5)
77

78
            def forward(self, x, y, z):
79
                res1 = self.celu1(x)
80
                res2 = self.celu2(y)
81
                for ln in self.lns:
82
                    z = ln(z)
83
                return res1 + res2, self.dropout(z)
84

85
        model = M(3)
86

87
        from onnxscript.onnx_opset import opset15 as op
88

89
        custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)
90

91
        @onnxscript.script(custom_opset)
92
        def layer_norm(
93
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
94
        ):
95
            mean = op.ReduceMean(X, axes=axes)
96
            D = X - mean  # op.Sub(X, mean)
97
            DD = D * D  # op.Mul(D, D)
98
            var = op.ReduceMean(DD, axes=axes)
99
            vareps = var + eps  # op.Add(var, eps)
100
            stddev = op.Sqrt(vareps)
101
            invstddev = op.Reciprocal(stddev)
102
            normalized = D * invstddev  # op.Mul(D, invstddev)
103
            normalizedw = op.CastLike(
104
                normalized, weight
105
            )  # Type issue if missing this Op
106
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
107
            return normalizedscaled + bias
108

109
        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
110
        def custom_layer_norm(
111
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
112
        ):
113
            # comprehension is not supported by onnxscript
114
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
115
            return g.onnxscript_op(
116
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
117
            ).setType(input.type())
118

119
        torch.onnx.register_custom_op_symbolic(
120
            symbolic_name="aten::layer_norm",
121
            symbolic_fn=custom_layer_norm,
122
            opset_version=self.opset_version,
123
        )
124

125
        self.run_test(model, (x, y, z))
126

127

128
if __name__ == "__main__":
129
    common_utils.run_tests()
130

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.