pytorch

Форк
0
/
test_onnxscript_runtime.py 
127 строк · 4.3 Кб
1
# Owner(s): ["module: onnx"]
2

3
"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
4
from typing import List
5

6
import onnx_test_common
7
import onnxscript
8
import torch
9
from onnxscript.onnx_types import FLOAT
10
from torch.onnx._internal import jit_utils
11
from torch.testing._internal import common_utils
12

13

14
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
15
    # opset version is
16
    # 1. local function is supported after opset 15
17
    # 2. onnx-script requires users to determine opset in local function
18
    opset_version = 15
19

20
    def test_selu_from_onnxscript_example(self):
21
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
22
        model = torch.nn.SELU()
23

24
        from onnxscript.onnx_opset import opset15 as op
25

26
        custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
27

28
        @onnxscript.script(custom_opset)
29
        def Selu(
30
            X,
31
        ):
32
            # default value is not supported by onnxscript
33
            alpha = 1.67326  # auto wrapped as Constants
34
            gamma = 1.0507
35
            alphaX = op.CastLike(alpha, X)
36
            gammaX = op.CastLike(gamma, X)
37
            neg = gammaX * (alphaX * op.Exp(X) - alphaX)
38
            pos = gammaX * X
39
            zero = op.CastLike(0, X)
40
            return op.Where(X <= zero, neg, pos)
41

42
        def custom_selu(g: jit_utils.GraphContext, X):
43
            return g.onnxscript_op(Selu, X).setType(X.type())
44

45
        torch.onnx.register_custom_op_symbolic(
46
            symbolic_name="aten::selu",
47
            symbolic_fn=custom_selu,
48
            opset_version=self.opset_version,
49
        )
50
        self.run_test(model, x)
51

52
    def test_layer_norm(self):
53
        x = torch.randn(2, 3)
54
        y = torch.randn(2, 3)
55
        z = torch.randn(2, 3)
56

57
        class N(torch.nn.Module):
58
            def __init__(self, prob):
59
                super().__init__()
60
                self.dropout = torch.nn.Dropout(prob)
61

62
            def forward(self, x):
63
                return self.dropout(x)
64

65
        class M(torch.nn.Module):
66
            def __init__(self, num_layers):
67
                super().__init__()
68
                self.num_layers = num_layers
69
                self.lns = torch.nn.ModuleList(
70
                    [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
71
                )
72
                self.celu1 = torch.nn.CELU(1.0)
73
                self.celu2 = torch.nn.CELU(2.0)
74
                self.dropout = N(0.5)
75

76
            def forward(self, x, y, z):
77
                res1 = self.celu1(x)
78
                res2 = self.celu2(y)
79
                for ln in self.lns:
80
                    z = ln(z)
81
                return res1 + res2, self.dropout(z)
82

83
        model = M(3)
84

85
        from onnxscript.onnx_opset import opset15 as op
86

87
        custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)
88

89
        @onnxscript.script(custom_opset)
90
        def layer_norm(
91
            X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
92
        ):
93
            mean = op.ReduceMean(X, axes=axes)
94
            D = X - mean  # op.Sub(X, mean)
95
            DD = D * D  # op.Mul(D, D)
96
            var = op.ReduceMean(DD, axes=axes)
97
            vareps = var + eps  # op.Add(var, eps)
98
            stddev = op.Sqrt(vareps)
99
            invstddev = op.Reciprocal(stddev)
100
            normalized = D * invstddev  # op.Mul(D, invstddev)
101
            normalizedw = op.CastLike(
102
                normalized, weight
103
            )  # Type issue if missing this Op
104
            normalizedscaled = normalizedw * weight  # op.Mul(normalized, weight)
105
            return normalizedscaled + bias
106

107
        @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
108
        def custom_layer_norm(
109
            g, input, normalized_shape, weight, bias, eps, cudnn_enable
110
        ):
111
            # comprehension is not supported by onnxscript
112
            axes = [-i for i in range(len(normalized_shape), 0, -1)]
113
            return g.onnxscript_op(
114
                layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
115
            ).setType(input.type())
116

117
        torch.onnx.register_custom_op_symbolic(
118
            symbolic_name="aten::layer_norm",
119
            symbolic_fn=custom_layer_norm,
120
            opset_version=self.opset_version,
121
        )
122

123
        self.run_test(model, (x, y, z))
124

125

126
if __name__ == "__main__":
127
    common_utils.run_tests()
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.