// onnxruntime — model test runner (129 lines · 5.1 KB)
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4import * as fs from 'fs-extra';
5import { InferenceSession, Tensor } from 'onnxruntime-common';
6import * as path from 'path';
7
8import { assertTensorEqual, atol, loadTensorFromFile, rtol, shouldSkipModel } from './test-utils';
9
10export function run(testDataRoot: string): void {
11const opsets = fs.readdirSync(testDataRoot);
12for (const opset of opsets) {
13const testDataFolder = path.join(testDataRoot, opset);
14const testDataFolderStat = fs.lstatSync(testDataFolder);
15if (testDataFolderStat.isDirectory()) {
16const models = fs.readdirSync(testDataFolder);
17
18for (const model of models) {
19// read each model folders
20const modelFolder = path.join(testDataFolder, model);
21let modelPath: string;
22const modelTestCases: Array<[Array<Tensor | undefined>, Array<Tensor | undefined>]> = [];
23for (const currentFile of fs.readdirSync(modelFolder)) {
24const currentPath = path.join(modelFolder, currentFile);
25const stat = fs.lstatSync(currentPath);
26if (stat.isFile()) {
27const ext = path.extname(currentPath);
28if (ext.toLowerCase() === '.onnx') {
29modelPath = currentPath;
30}
31} else if (stat.isDirectory()) {
32const inputs: Array<Tensor | undefined> = [];
33const outputs: Array<Tensor | undefined> = [];
34for (const dataFile of fs.readdirSync(currentPath)) {
35const dataFileFullPath = path.join(currentPath, dataFile);
36const ext = path.extname(dataFile);
37
38if (ext.toLowerCase() === '.pb') {
39let tensor: Tensor | undefined;
40try {
41tensor = loadTensorFromFile(dataFileFullPath);
42} catch (e) {
43console.warn(`[${model}] Failed to load test data: ${e.message}`);
44}
45
46if (dataFile.indexOf('input') !== -1) {
47inputs.push(tensor);
48} else if (dataFile.indexOf('output') !== -1) {
49outputs.push(tensor);
50}
51}
52}
53modelTestCases.push([inputs, outputs]);
54}
55}
56
57// add cases
58describe(`${opset}/${model}`, () => {
59let session: InferenceSession | null = null;
60let skipModel = shouldSkipModel(model, opset, ['cpu']);
61if (!skipModel) {
62before(async () => {
63try {
64session = await InferenceSession.create(modelPath);
65} catch (e) {
66// By default ort allows models with opsets from an official onnx release only. If it encounters
67// a model with opset > than released opset, ValidateOpsetForDomain throws an error and model load
68// fails. Since this is by design such a failure is acceptable in the context of this test. Therefore we
69// simply skip this test. Setting env variable ALLOW_RELEASED_ONNX_OPSET_ONLY=0 allows loading a model
70// with opset > released onnx opset.
71if (
72process.env.ALLOW_RELEASED_ONNX_OPSET_ONLY !== '0' &&
73e.message.includes('ValidateOpsetForDomain')
74) {
75session = null;
76console.log(`Skipping ${model}. To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY=0`);
77skipModel = true;
78} else {
79throw e;
80}
81}
82});
83} else {
84console.log(`[test-runner] skipped: ${model}`);
85}
86
87for (let i = 0; i < modelTestCases.length; i++) {
88const testCase = modelTestCases[i];
89const inputs = testCase[0];
90const expectedOutputs = testCase[1];
91if (!skipModel && !inputs.some((t) => t === undefined) && !expectedOutputs.some((t) => t === undefined)) {
92it(`case${i}`, async () => {
93if (skipModel) {
94return;
95}
96
97if (session !== null) {
98const feeds: Record<string, Tensor> = {};
99if (inputs.length !== session.inputNames.length) {
100throw new RangeError('input length does not match name list');
101}
102for (let i = 0; i < inputs.length; i++) {
103feeds[session.inputNames[i]] = inputs[i]!;
104}
105const outputs = await session.run(feeds);
106
107let j = 0;
108for (const name of session.outputNames) {
109assertTensorEqual(outputs[name], expectedOutputs[j++]!, atol(model), rtol(model));
110}
111} else {
112throw new TypeError('session is null');
113}
114});
115}
116}
117
118if (!skipModel) {
119after(async () => {
120if (session !== null) {
121await session.release();
122}
123});
124}
125});
126}
127}
128}
129}
130