onnxruntime

Форк
0
101 строка · 3.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
// TODO: this is the same naive implementation we use for reduce that has
5
// performance limitations when the reduced axis is long. Need to add
6
// a optimized codepath for this.
7

8
import { DataType } from '../../../wasm-common';
9
import { TensorView } from '../../tensor-view';
10
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
11
import { ComputeContext } from '../types';
12

13
import { createReduceProgramInfo, ReduceOp } from './reduce';
14

15
const validateInputs = (inputs: readonly TensorView[]): void => {
16
  if (!inputs || inputs.length === 0 || inputs.length > 2) {
17
    throw new Error('ArgMinMaxOp op requires 1 or 2 inputs.');
18
  }
19
  if (inputs[0].dataType !== DataType.float) {
20
    throw new Error('Invalid input type.');
21
  }
22
};
23

24
export interface ArgMinMaxAttributes extends AttributeWithCacheKey {
25
  keepDims: boolean;
26
  axis: number;
27
  selectLastIndex: number;
28
}
29

30
export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
31
  validateInputs(context.inputs);
32
  const argMinMaxOp: ReduceOp = (input, output, axes) => {
33
    const idxZero = [];
34
    for (let k = 0; k < input.rank; k++) {
35
      if (axes.indexOf(k) >= 0 || axes.length === 0) {
36
        idxZero.push(`input_indices[${k}] = 0;`); // first element
37
      }
38
    }
39
    return [
40
      `${idxZero.join('\n')}`,
41
      `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
42
      `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) {
43
         value = ${input.getByIndices('input_indices')};
44
         best_index = i32(last_index);
45
       }`,
46
      '',
47
      output.setByOffset('global_idx', 'best_index'),
48
    ];
49
  };
50

51
  context.compute(
52
    createReduceProgramInfo(
53
      'ArgMin',
54
      { hint: attributes.cacheKey, inputDependencies: ['rank'] },
55
      [context.inputs[0]],
56
      argMinMaxOp,
57
      [attributes.axis],
58
      DataType.int64,
59
      attributes.keepDims,
60
    ),
61
    { inputs: [0] },
62
  );
63
};
64

65
export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes): void => {
66
  validateInputs(context.inputs);
67
  const argMinMaxOp: ReduceOp = (input, output, axes) => {
68
    const idxZero = [];
69
    for (let k = 0; k < input.rank; k++) {
70
      if (axes.indexOf(k) >= 0 || axes.length === 0) {
71
        idxZero.push(`input_indices[${k}] = 0;`); // first element
72
      }
73
    }
74
    return [
75
      `${idxZero.join('\n')}`,
76
      `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
77
      `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) {
78
         value = ${input.getByIndices('input_indices')};
79
         best_index = i32(last_index);
80
       }`,
81
      '',
82
      output.setByOffset('global_idx', 'best_index'),
83
    ];
84
  };
85

86
  context.compute(
87
    createReduceProgramInfo(
88
      'argMax',
89
      { hint: attributes.cacheKey, inputDependencies: ['rank'] },
90
      [context.inputs[0]],
91
      argMinMaxOp,
92
      [attributes.axis],
93
      DataType.int64,
94
      attributes.keepDims,
95
    ),
96
    { inputs: [0] },
97
  );
98
};
99

100
export const parseArgMinMaxAttributes = (attributes: Record<string, unknown>): ArgMinMaxAttributes =>
101
  createAttributeWithCacheKey(attributes as Omit<ArgMinMaxAttributes, keyof AttributeWithCacheKey>);
102

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.