1
# Owner(s): ["module: onnx"]
3
"""Test the support for onnxscript in the PyTorch-ONNX converter with onnxruntime."""
9
from onnxscript.onnx_types import FLOAT
12
from torch.onnx._internal import jit_utils
13
from torch.testing._internal import common_utils
16
class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime):
# Test case whose two tests each register a custom symbolic function through
# torch.onnx.register_custom_op_symbolic and then call self.run_test.
# NOTE(review): this listing carries bare integer lines (apparently leftover
# line numbers), has no indentation, and some statements appear to be missing
# (unclosed calls, names used without a visible definition). Confirm against
# the canonical file before relying on it.
18
# 1. local function is supported after opset 15
19
# 2. onnx-script requires users to determine opset in local function
22
# Registers custom_selu as the symbolic function for "aten::selu" and runs
# torch.nn.SELU on a random (1, 2, 3, 4) tensor through self.run_test.
def test_selu_from_onnxscript_example(self):
23
x = torch.randn(1, 2, 3, 4, requires_grad=True)
24
model = torch.nn.SELU()
26
from onnxscript.onnx_opset import opset15 as op
28
# Custom opset (domain "onnx-script", version 1) passed to the script decorator.
custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1)
30
# NOTE(review): the function header this decorator applies to is not visible
# here; the name Selu used by custom_selu below is presumably that function.
@onnxscript.script(custom_opset)
34
# default value is not supported by onnxscript
35
alpha = 1.67326 # auto wrapped as Constants
37
# NOTE(review): gamma (read here) and pos (read in the return) have no visible
# assignment in this listing - confirm they exist in the full file.
# The scalars are cast with CastLike so they match X in the arithmetic below.
alphaX = op.CastLike(alpha, X)
38
gammaX = op.CastLike(gamma, X)
39
# Branch used where X <= 0: gamma * (alpha * exp(X) - alpha).
neg = gammaX * (alphaX * op.Exp(X) - alphaX)
41
zero = op.CastLike(0, X)
42
# op.Where selects neg where X <= zero and pos otherwise.
return op.Where(X <= zero, neg, pos)
44
# Symbolic function: emits Selu through g.onnxscript_op and gives the result
# the same type as the input X.
def custom_selu(g: jit_utils.GraphContext, X):
45
return g.onnxscript_op(Selu, X).setType(X.type())
47
# Registered for the opset version stored on the test case (self.opset_version).
torch.onnx.register_custom_op_symbolic(
48
symbolic_name="aten::selu",
49
symbolic_fn=custom_selu,
50
opset_version=self.opset_version,
52
self.run_test(model, x)
54
# Registers custom_layer_norm as the symbolic function for "aten::layer_norm"
# and passes a model with inputs (x, y, z) to self.run_test.
def test_layer_norm(self):
59
# Helper module holding a torch.nn.Dropout(prob); the visible return statement
# applies that dropout to x.
class N(torch.nn.Module):
60
def __init__(self, prob):
62
self.dropout = torch.nn.Dropout(prob)
65
return self.dropout(x)
67
# Helper module holding num_layers LayerNorm(3) layers (the i-th built with
# eps=i) and two CELU activations (alpha 1.0 and 2.0).
class M(torch.nn.Module):
68
def __init__(self, num_layers):
70
self.num_layers = num_layers
71
self.lns = torch.nn.ModuleList(
72
[torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)]
74
self.celu1 = torch.nn.CELU(1.0)
75
self.celu2 = torch.nn.CELU(2.0)
78
# NOTE(review): res1, res2 and self.dropout are used below but their
# definitions are not visible in this listing.
def forward(self, x, y, z):
83
return res1 + res2, self.dropout(z)
87
from onnxscript.onnx_opset import opset15 as op
89
# Custom opset (domain "onnxscript", version 1) passed to the script decorator.
custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1)
91
# NOTE(review): the function header is not visible; the parameter list below
# presumably belongs to the function referenced as layer_norm further down.
@onnxscript.script(custom_opset)
93
X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float
95
# Visible computation: (X - mean) * (1 / sqrt(var + eps)), then a CastLike,
# then * weight + bias, where mean and var are ReduceMean results over axes.
mean = op.ReduceMean(X, axes=axes)
96
D = X - mean # op.Sub(X, mean)
97
DD = D * D # op.Mul(D, D)
98
var = op.ReduceMean(DD, axes=axes)
99
vareps = var + eps # op.Add(var, eps)
100
stddev = op.Sqrt(vareps)
101
invstddev = op.Reciprocal(stddev)
102
normalized = D * invstddev # op.Mul(D, invstddev)
103
# NOTE(review): the arguments of this CastLike call are not visible here.
normalizedw = op.CastLike(
105
) # Type issue if missing this Op
106
normalizedscaled = normalizedw * weight # op.Mul(normalized, weight)
107
return normalizedscaled + bias
109
# Symbolic function for "aten::layer_norm". The parse_args descriptors are
# assumed to mean value / int list / value / value / float / ignored - confirm
# against torch.onnx.symbolic_helper.
@torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
110
def custom_layer_norm(
111
g, input, normalized_shape, weight, bias, eps, cudnn_enable
113
# comprehension is not supported by onnxscript
114
# For a normalized_shape of length n this yields [-n, ..., -1].
axes = [-i for i in range(len(normalized_shape), 0, -1)]
115
# axes and eps are passed under the keyword names axes_i / eps_f (presumably
# the exporter's typed-attribute suffix convention - confirm); the output is
# given the same type as input.
return g.onnxscript_op(
116
layer_norm, input, weight, bias, axes_i=axes, eps_f=eps
117
).setType(input.type())
119
# Registered for the opset version stored on the test case (self.opset_version).
torch.onnx.register_custom_op_symbolic(
120
symbolic_name="aten::layer_norm",
121
symbolic_fn=custom_layer_norm,
122
opset_version=self.opset_version,
125
# NOTE(review): model, x, y and z are not assigned in the visible lines.
self.run_test(model, (x, y, z))
128
if __name__ == "__main__":
129
# When the module is executed as a script, hand control to
# common_utils.run_tests().
common_utils.run_tests()