onnxruntime

Форк
0
/
test-runner.ts 
129 строк · 5.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import * as fs from 'fs-extra';
5
import { InferenceSession, Tensor } from 'onnxruntime-common';
6
import * as path from 'path';
7

8
import { assertTensorEqual, atol, loadTensorFromFile, rtol, shouldSkipModel } from './test-utils';
9

10
export function run(testDataRoot: string): void {
11
  const opsets = fs.readdirSync(testDataRoot);
12
  for (const opset of opsets) {
13
    const testDataFolder = path.join(testDataRoot, opset);
14
    const testDataFolderStat = fs.lstatSync(testDataFolder);
15
    if (testDataFolderStat.isDirectory()) {
16
      const models = fs.readdirSync(testDataFolder);
17

18
      for (const model of models) {
19
        // read each model folders
20
        const modelFolder = path.join(testDataFolder, model);
21
        let modelPath: string;
22
        const modelTestCases: Array<[Array<Tensor | undefined>, Array<Tensor | undefined>]> = [];
23
        for (const currentFile of fs.readdirSync(modelFolder)) {
24
          const currentPath = path.join(modelFolder, currentFile);
25
          const stat = fs.lstatSync(currentPath);
26
          if (stat.isFile()) {
27
            const ext = path.extname(currentPath);
28
            if (ext.toLowerCase() === '.onnx') {
29
              modelPath = currentPath;
30
            }
31
          } else if (stat.isDirectory()) {
32
            const inputs: Array<Tensor | undefined> = [];
33
            const outputs: Array<Tensor | undefined> = [];
34
            for (const dataFile of fs.readdirSync(currentPath)) {
35
              const dataFileFullPath = path.join(currentPath, dataFile);
36
              const ext = path.extname(dataFile);
37

38
              if (ext.toLowerCase() === '.pb') {
39
                let tensor: Tensor | undefined;
40
                try {
41
                  tensor = loadTensorFromFile(dataFileFullPath);
42
                } catch (e) {
43
                  console.warn(`[${model}] Failed to load test data: ${e.message}`);
44
                }
45

46
                if (dataFile.indexOf('input') !== -1) {
47
                  inputs.push(tensor);
48
                } else if (dataFile.indexOf('output') !== -1) {
49
                  outputs.push(tensor);
50
                }
51
              }
52
            }
53
            modelTestCases.push([inputs, outputs]);
54
          }
55
        }
56

57
        // add cases
58
        describe(`${opset}/${model}`, () => {
59
          let session: InferenceSession | null = null;
60
          let skipModel = shouldSkipModel(model, opset, ['cpu']);
61
          if (!skipModel) {
62
            before(async () => {
63
              try {
64
                session = await InferenceSession.create(modelPath);
65
              } catch (e) {
66
                // By default ort allows models with opsets from an official onnx release only. If it encounters
67
                // a model with opset > than released opset, ValidateOpsetForDomain throws an error and model load
68
                // fails. Since this is by design such a failure is acceptable in the context of this test. Therefore we
69
                // simply skip this test. Setting env variable ALLOW_RELEASED_ONNX_OPSET_ONLY=0 allows loading a model
70
                // with opset > released onnx opset.
71
                if (
72
                  process.env.ALLOW_RELEASED_ONNX_OPSET_ONLY !== '0' &&
73
                  e.message.includes('ValidateOpsetForDomain')
74
                ) {
75
                  session = null;
76
                  console.log(`Skipping ${model}. To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY=0`);
77
                  skipModel = true;
78
                } else {
79
                  throw e;
80
                }
81
              }
82
            });
83
          } else {
84
            console.log(`[test-runner] skipped: ${model}`);
85
          }
86

87
          for (let i = 0; i < modelTestCases.length; i++) {
88
            const testCase = modelTestCases[i];
89
            const inputs = testCase[0];
90
            const expectedOutputs = testCase[1];
91
            if (!skipModel && !inputs.some((t) => t === undefined) && !expectedOutputs.some((t) => t === undefined)) {
92
              it(`case${i}`, async () => {
93
                if (skipModel) {
94
                  return;
95
                }
96

97
                if (session !== null) {
98
                  const feeds: Record<string, Tensor> = {};
99
                  if (inputs.length !== session.inputNames.length) {
100
                    throw new RangeError('input length does not match name list');
101
                  }
102
                  for (let i = 0; i < inputs.length; i++) {
103
                    feeds[session.inputNames[i]] = inputs[i]!;
104
                  }
105
                  const outputs = await session.run(feeds);
106

107
                  let j = 0;
108
                  for (const name of session.outputNames) {
109
                    assertTensorEqual(outputs[name], expectedOutputs[j++]!, atol(model), rtol(model));
110
                  }
111
                } else {
112
                  throw new TypeError('session is null');
113
                }
114
              });
115
            }
116
          }
117

118
          if (!skipModel) {
119
            after(async () => {
120
              if (session !== null) {
121
                await session.release();
122
              }
123
            });
124
          }
125
        });
126
      }
127
    }
128
  }
129
}
130

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.