onnxruntime
42 lines · 1.5 KB
1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License.
3
4import unittest
5
6import torch
7
8from ..pytorch_export_helpers import infer_input_info
9
10# example usage from <ort root>/tools/python
11# python -m unittest util/test/test_pytorch_export_helpers.py
12# NOTE: at least on Windows you must use that as the working directory for all the imports to be happy
13
14
class TestModel(torch.nn.Module):
    """Small two-layer model used as the subject of the ``infer_input_info`` tests.

    ``forward`` takes one required tensor argument plus two optional keyword-capable
    arguments, so both positional and keyword call styles can be exercised.
    """

    # The class name starts with "Test" and it defines __init__, so pytest would try to
    # collect it as a test class and emit a PytestCollectionWarning. Opt out explicitly.
    __test__ = False

    def __init__(self, D_in, H, D_out):
        """Create the model.

        :param D_in: Size of the last dimension of the input passed to ``forward``.
        :param H: Hidden size between the two linear layers.
        :param D_out: Size of the last dimension of the output.
        """
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    # NOTE: `min`/`max` intentionally shadow the builtins. The parameter names are what
    # infer_input_info reports, and the tests assert ["x", "min", "max"], so don't rename them.
    def forward(self, x, min=0, max=1):
        """Apply linear1, clamp the result to [min, max], then apply linear2.

        :param x: Input tensor whose last dimension is ``D_in``.
        :param min: Lower clamp bound applied after the first layer.
        :param max: Upper clamp bound applied after the first layer.
        :return: Output tensor whose last dimension is ``D_out``.
        """
        step1 = self.linear1(x).clamp(min=min, max=max)
        step2 = self.linear2(step1)
        return step2
25
26
class TestInferInputs(unittest.TestCase):
    """Tests that ``infer_input_info`` derives input names and ordered inputs from ``forward``."""

    @classmethod
    def setUpClass(cls):
        # One model/input pair is created once and shared by every test in this class.
        cls._model = TestModel(1000, 100, 10)
        cls._input = torch.randn(1, 1000)

    def test_positional(self):
        # test we can infer the input names from the forward method when positional args are used,
        # and that the positional values come back unchanged and in order
        input_names, inputs_as_tuple = infer_input_info(self._model, self._input, 0, 1)
        self.assertEqual(input_names, ["x", "min", "max"])
        # Tuple equality checks identity before `==`, so the tensor element compares cleanly
        # as long as the same tensor object is returned.
        self.assertEqual(inputs_as_tuple, (self._input, 0, 1))

    def test_keywords(self):
        # test that keyword args are reordered, both names and values, to match the
        # parameter order of the module's forward method
        input_names, inputs_as_tuple = infer_input_info(self._model, self._input, max=1, min=0)
        self.assertEqual(input_names, ["x", "min", "max"])
        self.assertEqual(inputs_as_tuple, (self._input, 0, 1))
43