onnxruntime

Форк
0
195 строк · 7.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
10

11
import { createPackedConcatProgramInfoLoader } from './concat-packed';
12

13
/**
 * Attributes of the Concat operator as produced by `parseConcatAttributes`, together with the
 * cache key inherited from `AttributeWithCacheKey`.
 */
export interface ConcatAttributes extends AttributeWithCacheKey {
  /** Axis to concatenate along; negative values are counted back from the last dimension. */
  readonly axis: number;
}
16

17
/**
 * WebGL implementation of the Concat operator.
 *
 * Uses the packed-texture program when the session has packing enabled and the inputs have rank
 * greater than 1; otherwise falls back to the unpacked program defined in this file.
 *
 * @param inferenceHandler handler used to build and run the shader program.
 * @param inputs tensors to concatenate (validated by `validateInputs`).
 * @param attributes parsed Concat attributes (concatenation axis + cache key).
 * @returns a single-element array holding the concatenated tensor.
 */
export const concat: OperatorImplementation<ConcatAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConcatAttributes,
): Tensor[] => {
  validateInputs(inputs);

  // Packed textures are only used for inputs of rank > 1.
  const usePackedProgram = inferenceHandler.session.pack && inputs[0].dims.length > 1;
  const programInfoLoader = usePackedProgram
    ? createPackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes)
    : createUnpackedConcatProgramInfoLoader(inferenceHandler, inputs, attributes);

  return [inferenceHandler.run(programInfoLoader, inputs)];
};
37

38
/**
 * Builds the program metadata for the unpacked Concat shader.
 *
 * @param inputCount number of tensors being concatenated; one unpacked input named `X<i>` is
 *   declared for each of them.
 * @param cacheHint cache hint recorded in the metadata (the caller passes the attributes' cache key).
 * @returns metadata with name `'Concat'`, the input names/types and the cache hint.
 */
const createUnpackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => {
  const inputTypes = new Array(inputCount).fill(TextureType.unpacked);
  const inputNames = inputTypes.map((_textureType, position) => `X${position}`);
  return { name: 'Concat', inputNames, inputTypes, cacheHint };
};
44

45
/**
 * Builds the unpacked Concat shader program for `inputs` along `axis`.
 *
 * The output shape is the first input's shape with the concat dimension replaced by the sum of
 * that dimension over all inputs. The generated GLSL `process()` looks up which input holds a given
 * output coordinate, rebases the concat-axis index into that input, and fetches from it.
 *
 * @param _handler unused by this implementation.
 * @param metadata program metadata spread into the returned program info.
 * @param inputs tensors to concatenate; `inputs[0]` supplies the reference shape, rank and type.
 * @param axis concatenation axis; negative values are counted from the last dimension.
 * @returns program info carrying the computed output shape and the shader source.
 * @throws Error if `axis` lies outside `[-rank, rank)`, or if any input differs from `inputs[0]`
 *   on a non-concatenated dimension.
 */
const createUnpackedConcatProgramInfo = (
  _handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  axis: number,
): ProgramInfo => {
  const referenceShape = inputs[0].dims.slice();
  const rank = referenceShape.length;

  if (axis < -rank || axis >= rank) {
    throw new Error("axis specified for concat doesn't match input dimensionality");
  }
  const concatAxis = axis < 0 ? axis + rank : axis;

  // Accumulate the concat dimension; every other dimension must agree with inputs[0].
  const outputShape = referenceShape.slice();
  for (const tensor of inputs.slice(1)) {
    const otherDims = tensor.dims;
    referenceShape.forEach((extent, dim) => {
      if (dim === concatAxis) {
        outputShape[dim] += otherDims[dim];
      } else if (extent !== otherDims[dim]) {
        throw new Error('non concat dimensions must match');
      }
    });
  }

  // Inclusive running totals: entry k is the concat-axis offset at which input k ends.
  const sizeInConcatAxis: number[] = [];
  let runningTotal = 0;
  for (const tensor of inputs) {
    runningTotal += tensor.dims[concatAxis];
    sizeInConcatAxis.push(runningTotal);
  }

  // A linear scan is enough for the common case of only a few (typically two) inputs.
  const getTextureIndexWhereDataResidesMethod =
    inputs.length >= 5
      ? getTextureIndexWhereDataResidesBinarySearch(sizeInConcatAxis)
      : getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis);

  const fetchDataFromCorrectTextureMethod = getFetchDataFromCorrectTextureMethod(inputs.length, rank);
  const getSizeInConcatAxisValueFromIndexMethod = getGetSizeInConcatAxisValueFromIndexMethod(sizeInConcatAxis);

  const shaderSource = `
        ${fetchDataFromCorrectTextureMethod}
        ${getSizeInConcatAxisValueFromIndexMethod}
        ${getTextureIndexWhereDataResidesMethod}
        float process(int indices[${rank}]) {
          int textureIndex = getTextureWhereDataResides (indices[${concatAxis}]);

          if(textureIndex != 0) {
            indices[${concatAxis}] = indices[${concatAxis}] - int(getSizeInConcatAxisValueFromIndex(textureIndex-int(1)));
          }

          return fetchDataFromCorrectTexture(textureIndex, indices);
        }`;

  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
113

114
/**
 * Creates a loader for the unpacked Concat program.
 *
 * The metadata is computed eagerly; the full program info, including the shader source, is only
 * built when `get()` is called.
 *
 * @param handler inference handler forwarded to the program-info factory.
 * @param inputs tensors to concatenate.
 * @param attributes Concat attributes; `cacheKey` feeds the metadata, `axis` is read inside `get()`.
 */
const createUnpackedConcatProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConcatAttributes,
): ProgramInfoLoader => {
  const metadata = createUnpackedConcatProgramMetadata(inputs.length, attributes.cacheKey);
  const get = (): ProgramInfo => createUnpackedConcatProgramInfo(handler, metadata, inputs, attributes.axis);
  return { ...metadata, get };
};
122

123
/**
 * Generates the GLSL function `int getTextureWhereDataResides(int index)`, which maps an index
 * along the concat axis to the position of the input that holds it.
 *
 * @param sizeInConcatAxis inclusive running totals of the inputs' concat-axis extents
 *   (entry i is where input i ends).
 * @returns GLSL source for the lookup function.
 */
const getTextureIndexWhereDataResidesLinearSearch = (sizeInConcatAxis: number[]): string => {
  // One `if` per input, in order: input i owns concat-axis indices below sizeInConcatAxis[i].
  const branches = sizeInConcatAxis.map((size, i) => `if(index<${size}) {return ${i};}\n`).join('');
  // GLSL does not define the result of a non-void function that reaches its end without a
  // `return`. Close with an explicit fallback to the last input, which is reached only for
  // indices at or beyond the final boundary.
  const lastTextureIndex = sizeInConcatAxis.length - 1;
  return `int getTextureWhereDataResides(int index) {
      ${branches}      return ${lastTextureIndex};
    }`;
};
132

133
/**
 * Intended binary-search variant of the texture lookup for many inputs.
 *
 * TODO: Implement BinarySearch in GLSL. Until then this emits exactly the same GLSL as
 * `getTextureIndexWhereDataResidesLinearSearch`.
 *
 * @param sizeInConcatAxis inclusive running totals of the inputs' concat-axis extents.
 * @returns GLSL source for `getTextureWhereDataResides`.
 */
const getTextureIndexWhereDataResidesBinarySearch = (sizeInConcatAxis: number[]): string => {
  const linearSearchSource = getTextureIndexWhereDataResidesLinearSearch(sizeInConcatAxis);
  return linearSearchSource;
};
136

137
/**
 * Generates the GLSL function `float fetchDataFromCorrectTexture(int textureIndex, int indices[rank])`.
 * It dispatches on `textureIndex` to the generated `_X<i>(indices)` accessor of the matching input.
 *
 * @param numberOfTensors number of concatenated inputs (`_X0` .. `_X<n-1>`).
 * @param tensorRank rank of the inputs, used as the size of the `indices` array.
 * @returns GLSL source for the dispatch function.
 */
const getFetchDataFromCorrectTextureMethod = (numberOfTensors: number, tensorRank: number): string => {
  const codeLines: string[] = [`float fetchDataFromCorrectTexture(int textureIndex, int indices[${tensorRank}]) {`];
  if (numberOfTensors === 1) {
    // With one input there is nothing to dispatch on. A lone `if` without an `else` would leave
    // a path with no return statement in this non-void GLSL function.
    codeLines.push('\t' + 'return _X0(indices);');
  } else {
    for (let i = 0; i < numberOfTensors; ++i) {
      if (i === 0) {
        codeLines.push('\t' + `if (textureIndex == ${i}) { return _X${i}(indices); }`);
      } else if (i === numberOfTensors - 1) {
        // The last input is the unconditional fallback, so every path returns a value.
        codeLines.push('\t' + `else { return _X${i}(indices); }`);
      } else {
        codeLines.push('\t' + `else if (textureIndex == ${i}) { return _X${i}(indices); }`);
      }
    }
  }
  codeLines.push('\t' + '}');
  return codeLines.join('\n');
};
151

152
/**
 * Generates the GLSL function `int getSizeInConcatAxisValueFromIndex(int index)`, which returns
 * `sizeInConcatAxis[index]` (the concat-axis offset at which input `index` ends).
 *
 * @param sizeInConcatAxis inclusive running totals of the inputs' concat-axis extents.
 * @returns GLSL source for the lookup function.
 */
const getGetSizeInConcatAxisValueFromIndexMethod = (sizeInConcatAxis: number[]): string => {
  const codeLines: string[] = ['int getSizeInConcatAxisValueFromIndex(int index) {'];
  if (sizeInConcatAxis.length === 1) {
    // With one input there is nothing to select. A lone `if` without an `else` would leave a path
    // with no return statement in this non-void GLSL function.
    codeLines.push('\t' + `return ${sizeInConcatAxis[0]};`);
  } else {
    for (let i = 0; i < sizeInConcatAxis.length; ++i) {
      if (i === 0) {
        codeLines.push('\t' + `if (index == ${i}) { return ${sizeInConcatAxis[i]}; }`);
      } else if (i === sizeInConcatAxis.length - 1) {
        // The last entry is the unconditional fallback, so every path returns a value.
        codeLines.push('\t' + `else { return ${sizeInConcatAxis[i]}; }`);
      } else {
        codeLines.push('\t' + `else if (index == ${i}) { return ${sizeInConcatAxis[i]}; }`);
      }
    }
  }
  codeLines.push('\t' + '}');

  return codeLines.join('\n');
};
167

168
/**
 * Reads the integer `axis` attribute from a Concat graph node and wraps it with a cache key.
 *
 * @param node graph node whose attributes hold `axis`.
 * @returns the parsed Concat attributes.
 */
export const parseConcatAttributes: OperatorInitialization<ConcatAttributes> = (node: Graph.Node): ConcatAttributes => {
  const axis = node.attributes.getInt('axis');
  return createAttributeWithCacheKey({ axis });
};
170

171
/**
 * Rejects input lists the Concat kernels cannot handle.
 *
 * @param inputs tensors passed to the operator.
 * @throws Error when there are no inputs, when the first input is a string tensor, or when any
 *   input's type or rank (number of dimensions) differs from the first input's.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length === 0) {
    throw new Error('too few inputs');
  }

  const expectedType = inputs[0].type;
  const expectedRank = inputs[0].dims.length;

  // TODO: Support string concat
  if (expectedType === 'string') {
    throw new Error('string tensor is not supported yet');
  }

  // Check each tensor in turn; the first offending input decides which error is reported.
  for (let position = 0; position < inputs.length; position++) {
    const candidate = inputs[position];
    if (candidate.type !== expectedType) {
      throw new Error('input tensors should be one type');
    }
    // Only the rank is compared here; non-concat dimensions are checked when the program is built.
    if (candidate.dims.length !== expectedRank) {
      throw new Error('input tensors should have the same shape');
    }
  }
};
196

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.