onnxruntime

Форк
0
99 строк · 3.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
11

12
export interface TransposeAttributes extends AttributeWithCacheKey {
13
  readonly perm: number[];
14
}
15

16
const transposeProgramMetadata = {
17
  name: 'Transpose',
18
  inputNames: ['A'],
19
  inputTypes: [TextureType.unpacked],
20
};
21

22
export const transpose: OperatorImplementation<TransposeAttributes> = (
23
  inferenceHandler: WebGLInferenceHandler,
24
  inputs: Tensor[],
25
  attributes: TransposeAttributes,
26
): Tensor[] => {
27
  validateInputs(inputs);
28
  const output = inferenceHandler.run(
29
    {
30
      ...transposeProgramMetadata,
31
      cacheHint: attributes.cacheKey,
32
      get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm),
33
    },
34
    inputs,
35
  );
36
  return [output];
37
};
38

39
export const parseTransposeAttributes: OperatorInitialization<TransposeAttributes> = (
40
  node: Graph.Node,
41
): TransposeAttributes => createAttributeWithCacheKey({ perm: node.attributes.getInts('perm', []) });
42

43
const createTransposeProgramInfo = (
44
  _inferenceHandler: WebGLInferenceHandler,
45
  input: Tensor,
46
  perm: number[],
47
): ProgramInfo => {
48
  const inputShape = input.dims;
49
  perm = getAdjustedPerm(inputShape, perm);
50
  const unpackedOutputShape = getOutputShape(inputShape, perm);
51
  const rank = inputShape.length;
52
  // A dims=[${inputs[0].dims.toString()}]
53
  // out Dims=[${unpackedOutputShape.toString()}]
54
  // based on perm=[${perm.toString()}]
55
  const shaderSource = `
56
      ${getPermFunctionBody('perm', perm, rank)}
57
      float process(int indices[${rank}]) {
58
        int a[${rank}];
59
        perm(a, indices);
60
        return _A(a);
61
      }`;
62
  return {
63
    ...transposeProgramMetadata,
64
    output: { dims: unpackedOutputShape, type: input.type, textureType: TextureType.unpacked },
65
    shaderSource,
66
  };
67
};
68

69
const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => {
70
  if (perm && perm.length !== inputShape.length) {
71
    perm = [...inputShape.keys()].reverse();
72
  }
73
  return perm;
74
};
75

76
const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => {
77
  perm = getAdjustedPerm(inputShape, perm);
78
  return ShapeUtil.sortBasedOnPerm(inputShape, perm);
79
};
80

81
const getPermFunctionBody = (name: string, perm: number[], rank: number): string => {
82
  const reverseFunc = [];
83
  reverseFunc.push(`void ${name}(out int a[${rank}], int src[${rank}]) {`);
84
  for (let i = 0; i < rank; ++i) {
85
    reverseFunc.push(`\ta[${perm[i]}]=src[${i}];`);
86
  }
87
  reverseFunc.push('\t}');
88
  return reverseFunc.join('\n');
89
};
90

91
const validateInputs = (inputs: Tensor[]): void => {
92
  if (!inputs || inputs.length !== 1) {
93
    throw new Error('Transpose requires 1 input.');
94
  }
95

96
  if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
97
    throw new Error('input should be float tensor');
98
  }
99
};
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.