onnxruntime

Форк
0
/
simple-e2e-tests.ts 
113 строк · 4.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import assert from 'assert';
5
import { InferenceSession, Tensor } from 'onnxruntime-common';
6
import * as path from 'path';
7

8
import { assertDataEqual, TEST_DATA_ROOT } from '../test-utils';
9

10
const MODEL_TEST_TYPES_CASES: Array<{
11
  model: string;
12
  type: Tensor.Type;
13
  input0: Tensor.DataType;
14
  expectedOutput0: Tensor.DataType;
15
}> = [
16
  {
17
    model: path.join(TEST_DATA_ROOT, 'test_types_bool.onnx'),
18
    type: 'bool',
19
    input0: Uint8Array.from([1, 0, 0, 1, 0]),
20
    expectedOutput0: Uint8Array.from([1, 0, 0, 1, 0]),
21
  },
22
  {
23
    model: path.join(TEST_DATA_ROOT, 'test_types_double.onnx'),
24
    type: 'float64',
25
    input0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
26
    expectedOutput0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
27
  },
28
  {
29
    model: path.join(TEST_DATA_ROOT, 'test_types_float.onnx'),
30
    type: 'float32',
31
    input0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
32
    expectedOutput0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
33
  },
34
  {
35
    model: path.join(TEST_DATA_ROOT, 'test_types_int8.onnx'),
36
    type: 'int8',
37
    input0: Int8Array.from([1, -2, 3, 4, -5]),
38
    expectedOutput0: Int8Array.from([1, -2, 3, 4, -5]),
39
  },
40
  {
41
    model: path.join(TEST_DATA_ROOT, 'test_types_int16.onnx'),
42
    type: 'int16',
43
    input0: Int16Array.from([1, -2, 3, 4, -5]),
44
    expectedOutput0: Int16Array.from([1, -2, 3, 4, -5]),
45
  },
46
  {
47
    model: path.join(TEST_DATA_ROOT, 'test_types_int32.onnx'),
48
    type: 'int32',
49
    input0: Int32Array.from([1, -2, 3, 4, -5]),
50
    expectedOutput0: Int32Array.from([1, -2, 3, 4, -5]),
51
  },
52
  {
53
    model: path.join(TEST_DATA_ROOT, 'test_types_int64.onnx'),
54
    type: 'int64',
55
    input0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]),
56
    expectedOutput0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]),
57
  },
58
  {
59
    model: path.join(TEST_DATA_ROOT, 'test_types_string.onnx'),
60
    type: 'string',
61
    input0: ['a', 'b', 'c', 'd', 'e'],
62
    expectedOutput0: ['a', 'b', 'c', 'd', 'e'],
63
  },
64
  {
65
    model: path.join(TEST_DATA_ROOT, 'test_types_uint8.onnx'),
66
    type: 'uint8',
67
    input0: Uint8Array.from([1, 2, 3, 4, 5]),
68
    expectedOutput0: Uint8Array.from([1, 2, 3, 4, 5]),
69
  },
70
  {
71
    model: path.join(TEST_DATA_ROOT, 'test_types_uint16.onnx'),
72
    type: 'uint16',
73
    input0: Uint16Array.from([1, 2, 3, 4, 5]),
74
    expectedOutput0: Uint16Array.from([1, 2, 3, 4, 5]),
75
  },
76
  {
77
    model: path.join(TEST_DATA_ROOT, 'test_types_uint32.onnx'),
78
    type: 'uint32',
79
    input0: Uint32Array.from([1, 2, 3, 4, 5]),
80
    expectedOutput0: Uint32Array.from([1, 2, 3, 4, 5]),
81
  },
82
  {
83
    model: path.join(TEST_DATA_ROOT, 'test_types_uint64.onnx'),
84
    type: 'uint64',
85
    input0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]),
86
    expectedOutput0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]),
87
  },
88
];
89

90
describe('E2E Tests - simple E2E tests', () => {
91
  MODEL_TEST_TYPES_CASES.forEach((testCase) => {
92
    it(`${testCase.model}`, async () => {
93
      const session = await InferenceSession.create(testCase.model);
94
      const output = await session.run({ input: new Tensor(testCase.type, testCase.input0, [1, 5]) });
95
      assert(Object.prototype.hasOwnProperty.call(output, 'output'), "'output' should be in the result object.");
96
      assert(output.output instanceof Tensor, 'result[output] should be a Tensor object.');
97
      assert.strictEqual(output.output.size, 5, `output size expected 5, got ${output.output.size}.`);
98
      assert.strictEqual(
99
        output.output.type,
100
        testCase.type,
101
        `tensor type expected ${testCase.type}, got ${output.output.type}.`,
102
      );
103
      assert.strictEqual(
104
        Object.getPrototypeOf(output.output.data),
105
        Object.getPrototypeOf(testCase.expectedOutput0),
106
        `tensor data expected ${Object.getPrototypeOf(testCase.expectedOutput0).constructor.name}, got ${
107
          Object.getPrototypeOf(output.output.data).constructor.name
108
        }`,
109
      );
110
      assertDataEqual(testCase.type, output.output.data, testCase.expectedOutput0);
111
    });
112
  });
113
});
114

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.