onnxruntime
113 строк · 4.1 Кб
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import assert from 'assert';
import { InferenceSession, Tensor } from 'onnxruntime-common';
import * as path from 'path';

import { assertDataEqual, TEST_DATA_ROOT } from '../test-utils';

/**
 * Shape of one entry in MODEL_TEST_TYPES_CASES.
 */
interface ModelTestTypesCase {
  /** Path of the `.onnx` model to load, resolved under TEST_DATA_ROOT. */
  readonly model: string;
  /** Element type of the input tensor, and the type the output tensor must report. */
  readonly type: Tensor.Type;
  /** Data for the model's single input tensor (five elements). */
  readonly input0: Tensor.DataType;
  /** Data the model's output tensor is expected to hold. */
  readonly expectedOutput0: Tensor.DataType;
}

/**
 * One case per tensor element type. In every entry `expectedOutput0` holds the same
 * values as `input0`, i.e. each `test_types_*.onnx` model is expected to hand its
 * input back unchanged. The entries differ in element type and in the JS container
 * (typed array or string array) that carries the data.
 */
const MODEL_TEST_TYPES_CASES: readonly ModelTestTypesCase[] = [
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_bool.onnx'),
    type: 'bool',
    // 'bool' data is supplied and compared as a Uint8Array of 0/1 values.
    input0: Uint8Array.from([1, 0, 0, 1, 0]),
    expectedOutput0: Uint8Array.from([1, 0, 0, 1, 0]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_double.onnx'),
    type: 'float64',
    input0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
    expectedOutput0: Float64Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_float.onnx'),
    type: 'float32',
    input0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
    expectedOutput0: Float32Array.from([1.0, 2.0, 3.0, 4.0, 5.0]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_int8.onnx'),
    type: 'int8',
    input0: Int8Array.from([1, -2, 3, 4, -5]),
    expectedOutput0: Int8Array.from([1, -2, 3, 4, -5]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_int16.onnx'),
    type: 'int16',
    input0: Int16Array.from([1, -2, 3, 4, -5]),
    expectedOutput0: Int16Array.from([1, -2, 3, 4, -5]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_int32.onnx'),
    type: 'int32',
    input0: Int32Array.from([1, -2, 3, 4, -5]),
    expectedOutput0: Int32Array.from([1, -2, 3, 4, -5]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_int64.onnx'),
    type: 'int64',
    // BigInt(...) calls rather than `1n` literals — presumably so this still compiles
    // under a pre-ES2020 target; confirm the tsconfig before switching to literals.
    input0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]),
    expectedOutput0: BigInt64Array.from([BigInt(1), BigInt(-2), BigInt(3), BigInt(4), BigInt(-5)]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_string.onnx'),
    type: 'string',
    // String tensors use a plain string[] rather than a typed array.
    input0: ['a', 'b', 'c', 'd', 'e'],
    expectedOutput0: ['a', 'b', 'c', 'd', 'e'],
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_uint8.onnx'),
    type: 'uint8',
    input0: Uint8Array.from([1, 2, 3, 4, 5]),
    expectedOutput0: Uint8Array.from([1, 2, 3, 4, 5]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_uint16.onnx'),
    type: 'uint16',
    input0: Uint16Array.from([1, 2, 3, 4, 5]),
    expectedOutput0: Uint16Array.from([1, 2, 3, 4, 5]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_uint32.onnx'),
    type: 'uint32',
    input0: Uint32Array.from([1, 2, 3, 4, 5]),
    expectedOutput0: Uint32Array.from([1, 2, 3, 4, 5]),
  },
  {
    model: path.join(TEST_DATA_ROOT, 'test_types_uint64.onnx'),
    type: 'uint64',
    input0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]),
    expectedOutput0: BigUint64Array.from([BigInt(1), BigInt(2), BigInt(3), BigInt(4), BigInt(5)]),
  },
];
/**
 * For each entry in MODEL_TEST_TYPES_CASES:
 * - load the model;
 * - feed `input0` as a [1, 5] tensor named `input`;
 * - check that the result has an `output` Tensor;
 * - check that the Tensor has 5 elements and the same element type;
 * - check that its data uses the same container class (typed array / array) as
 *   `expectedOutput0`;
 * - check that the data matches `expectedOutput0`.
 */
describe('E2E Tests - simple E2E tests', () => {
  MODEL_TEST_TYPES_CASES.forEach((testCase) => {
    it(`${testCase.model}`, async () => {
      const session = await InferenceSession.create(testCase.model);
      try {
        const output = await session.run({ input: new Tensor(testCase.type, testCase.input0, [1, 5]) });

        assert(Object.prototype.hasOwnProperty.call(output, 'output'), "'output' should be in the result object.");
        assert(output.output instanceof Tensor, 'result[output] should be a Tensor object.');
        assert.strictEqual(output.output.size, 5, `output size expected 5, got ${output.output.size}.`);
        assert.strictEqual(
          output.output.type,
          testCase.type,
          `tensor type expected ${testCase.type}, got ${output.output.type}.`,
        );
        // Comparing prototypes ensures the data comes back in the exact container class
        // (e.g. BigInt64Array vs. Int32Array), not just with equal element values.
        assert.strictEqual(
          Object.getPrototypeOf(output.output.data),
          Object.getPrototypeOf(testCase.expectedOutput0),
          `tensor data expected ${Object.getPrototypeOf(testCase.expectedOutput0).constructor.name}, got ${
            Object.getPrototypeOf(output.output.data).constructor.name
          }`,
        );
        assertDataEqual(testCase.type, output.output.data, testCase.expectedOutput0);
      } finally {
        // Release the session in every case, including when an assertion above throws,
        // so its underlying resources are not held for the rest of the test run.
        await session.release();
      }
    });
  });
});