onnxruntime

Форк
0
339 строк · 11.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { getGlsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, TextureType } from '../types';
12

13
import { transpose, TransposeAttributes } from './transpose';
14

15
/** Attributes for the Softmax operator; `cacheKey` comes from AttributeWithCacheKey. */
export interface SoftmaxAttributes extends AttributeWithCacheKey {
  // Axis along which softmax is computed; may be negative before normalization.
  readonly axis: number;
}
18

19
// Program metadata for pass 1 of softmax: the per-row max reduction.
const softmaxComputeMaxProgramMetadata = {
  name: 'SoftmaxComputeMax',
  inputNames: ['A'],
  inputTypes: [TextureType.unpacked],
};

// Program metadata for pass 2: the per-row normalization factor.
// Consumes the input tensor and the per-row max from pass 1.
const softmaxComputeScaleProgramMetadata = {
  name: 'SoftmaxComputeScale',
  inputNames: ['A', 'Max'],
  inputTypes: [TextureType.unpacked, TextureType.unpacked],
};

// Program metadata for pass 3: the final softmax output.
// Consumes the input, the per-row max, and the per-row normalization factor.
const softmaxProgramMetadata = {
  name: 'SoftMax',
  inputNames: ['A', 'Max', 'Norm'],
  inputTypes: [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked],
};
36

37
export const softmax: OperatorImplementation<SoftmaxAttributes> = (
38
  inferenceHandler: WebGLInferenceHandler,
39
  inputs: Tensor[],
40
  attributes: SoftmaxAttributes,
41
): Tensor[] => {
42
  validateInputs(inputs);
43

44
  const inputShape = inputs[0].dims.slice();
45
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
46
  const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, axis);
47
  const featureCount = ShapeUtil.sizeFromDimension(inputShape, axis);
48

49
  const output = computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, featureCount);
50
  return output;
51
};
52

53
// Attribute parser for Softmax before opset-13: the ONNX default axis is 1.
export const parseSoftmaxAttributes: OperatorInitialization<SoftmaxAttributes> = (
  node: Graph.Node,
): SoftmaxAttributes => createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', 1) });
56

57
// Attribute parser for Softmax opset-13+: the ONNX default axis changed to -1.
export const parseSoftmaxAttributesV13: OperatorInitialization<SoftmaxAttributes> = (
  node: Graph.Node,
): SoftmaxAttributes => createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', -1) });
60

61
// The "semantic" meaning of axis has changed in opset-13.
62
// Please compare: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softmax
63
// with https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Softmax-11 for detailed explanations
64
// To account for the opset-13 behavior, our plan will be to transpose the "axis" dim to the innermost dim
65
// and perform softmax and then reverse the transpose. We can skip the transposing aspect if the axis is already
66
// the innermost dim
67
export const softmaxV13: OperatorImplementation<SoftmaxAttributes> = (
68
  inferenceHandler: WebGLInferenceHandler,
69
  inputs: Tensor[],
70
  attributes: SoftmaxAttributes,
71
): Tensor[] => {
72
  validateInputs(inputs);
73

74
  const inputShape = inputs[0].dims.slice();
75
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
76
  const rank = inputShape.length;
77

78
  const isTransposeRequired = axis !== rank - 1 ? true : false;
79
  const transposedInputShape: number[] = [];
80
  let perm: number[] = [];
81
  let transposedInputs: Tensor[] = [];
82
  let transposeAttribute: TransposeAttributes;
83

84
  if (isTransposeRequired) {
85
    perm = Array.from({ length: rank }).map((_, i) => i);
86

87
    // swap the innermost dim with the dim corresponding to axis
88
    perm[axis] = rank - 1;
89
    perm[rank - 1] = axis;
90

91
    perm.map((p) => transposedInputShape.push(inputShape[p]));
92

93
    transposeAttribute = createAttributeWithCacheKey({ perm });
94
    transposedInputs = transpose(inferenceHandler, inputs, transposeAttribute);
95
  }
96

97
  const logicalRowCount = isTransposeRequired
98
    ? ShapeUtil.sizeToDimension(transposedInputShape, rank - 1)
99
    : ShapeUtil.sizeToDimension(inputShape, rank - 1);
100
  const featureCount = isTransposeRequired
101
    ? ShapeUtil.sizeFromDimension(transposedInputShape, rank - 1)
102
    : ShapeUtil.sizeFromDimension(inputShape, rank - 1);
103

104
  const output = computeSoftmax(
105
    inferenceHandler,
106
    isTransposeRequired ? transposedInputs : inputs,
107
    attributes,
108
    logicalRowCount,
109
    featureCount,
110
  );
111

112
  if (isTransposeRequired) {
113
    const reversedOutput = transpose(inferenceHandler, output, transposeAttribute!);
114
    return reversedOutput;
115
  } else {
116
    return output;
117
  }
118
};
119

120
const computeSoftmax = (
121
  inferenceHandler: WebGLInferenceHandler,
122
  inputs: Tensor[],
123
  attributes: SoftmaxAttributes,
124
  logicalRowCount: number,
125
  featureCount: number,
126
): Tensor[] => {
127
  const computeMaxProgramInfo = createComputeMaxProgramInfo(
128
    inferenceHandler,
129
    inputs[0],
130
    logicalRowCount,
131
    featureCount,
132
    [logicalRowCount],
133
  );
134
  const max = inferenceHandler.run(
135
    { ...softmaxComputeMaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeMaxProgramInfo },
136
    inputs,
137
  );
138

139
  const computeScaleProgramInfo = createComputScaleProgramInfo(
140
    inferenceHandler,
141
    inputs[0],
142
    logicalRowCount,
143
    featureCount,
144
    computeMaxProgramInfo.output.dims,
145
    [logicalRowCount],
146
  );
147
  const scale = inferenceHandler.run(
148
    { ...softmaxComputeScaleProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeScaleProgramInfo },
149
    [inputs[0], max],
150
  );
151

152
  const softMaxProgramInfo = createSoftMaxProgramInfo(
153
    inferenceHandler,
154
    inputs[0],
155
    logicalRowCount,
156
    featureCount,
157
    computeMaxProgramInfo.output.dims,
158
    computeScaleProgramInfo.output.dims,
159
  );
160
  const output = inferenceHandler.run(
161
    { ...softmaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => softMaxProgramInfo },
162
    [inputs[0], max, scale],
163
  );
164
  return [output];
165
};
166

167
/**
168
 * Create a texture that contains the maximum value of each of the 'N' rows
169
 */
170
const createComputeMaxProgramInfo = (
171
  inferenceHandler: WebGLInferenceHandler,
172
  input: Tensor,
173
  logicalRowCount: number,
174
  featureCount: number,
175
  outputShape: number[],
176
): ProgramInfo => {
177
  const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
178
    input.dims,
179
    TextureType.unpacked,
180
  );
181
  const rank = outputShape.length;
182

183
  if (logicalRowCount < 1 || featureCount < 1) {
184
    throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
185
  }
186

187
  if (outputShape.length !== 1) {
188
    throw new Error('Dimensionality of the output should be 1');
189
  }
190

191
  if (outputShape[0] !== logicalRowCount) {
192
    throw new Error('Shape of the output should be equal to logical row count');
193
  }
194

195
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
196
  const shaderSource = `
197
      float process(int[${rank}] indices) {
198
        int logical_row_start_offset = indices[0] * ${featureCount};
199

200
        float max = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset, ${textureWidth},
201
        ${textureHeight} )));
202
        for(int i=1; i<${featureCount}; ++i)
203
        {
204
          float current = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
205
            ${textureWidth}, ${textureHeight})));
206
          if(current > max)
207
          max = current;
208
        }
209

210
        return max;
211
      }`;
212
  return {
213
    ...softmaxComputeMaxProgramMetadata,
214
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
215
    shaderSource,
216
  };
217
};
218

219
/**
 * Create a texture that contains the normalization factor for each of the 'N' rows
 *
 * The factor for a row is sum(exp(x_i - row_max)) over its D features; the
 * per-row max from the previous pass is read through the generated `_Max`
 * accessor. (NOTE(review): the missing 'e' in "Comput" looks like a typo, but
 * the name is referenced elsewhere in this file, so it is kept as-is.)
 */
const createComputScaleProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  logicalRowCount: number,
  featureCount: number,
  maxElementPerLogicalRow: readonly number[],
  outputShape: number[],
): ProgramInfo => {
  // Texture dimensions of the (unpacked) input, used by the shader to map
  // flat offsets to 2-D texture coordinates.
  const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    input.dims,
    TextureType.unpacked,
  );
  const rank = outputShape.length;

  if (logicalRowCount < 1 || featureCount < 1) {
    throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
  }

  if (outputShape.length !== 1) {
    throw new Error('Dimensionality of the output should be 1');
  }

  if (outputShape[0] !== logicalRowCount) {
    throw new Error('Shape of the output should be equal to logical row count');
  }

  if (maxElementPerLogicalRow.length !== 1) {
    throw new Error('Dimensionality of the intermediate results should be 1');
  }

  if (maxElementPerLogicalRow[0] !== logicalRowCount) {
    throw new Error('Shape of the intermediate results should be equal to logical row count');
  }

  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  // Subtracting the per-row max before exp() is the standard shift that keeps
  // the exponentials from overflowing.
  const shaderSource = `
      float process(int[${rank}] indices) {
        int logical_row_start_offset = indices[0] * ${featureCount};

        float norm_factor = 0.0;
        float max = _Max(indices);
        for(int i=0; i<${featureCount}; ++i)
        {
          norm_factor += exp(getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
            ${textureWidth}, ${textureHeight}))) - max);
        }

        return norm_factor;
      }`;
  return {
    ...softmaxComputeScaleProgramMetadata,
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
277

278
/**
 * Create the final softmax program: for every element, exp(x - row_max)
 * divided by the row's normalization factor. Reads the two intermediate
 * textures through the generated `_Max` and `_Norm` accessors.
 */
const createSoftMaxProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  logicalRowCount: number,
  featureCount: number,
  maxElementPerLogicalRow: readonly number[],
  normalizationPerLogicalRow: readonly number[],
): ProgramInfo => {
  // Texture dimensions of the (unpacked) input; the shader maps texture
  // coordinates back to a flat offset to find the logical row.
  const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    input.dims,
    TextureType.unpacked,
  );
  // Unlike the two reduction passes, the output has the input's full rank.
  const rank = input.dims.length;

  if (logicalRowCount < 1 || featureCount < 1) {
    throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
  }

  if (maxElementPerLogicalRow.length !== 1 || normalizationPerLogicalRow.length !== 1) {
    throw new Error('Dimensionality of the intermediate results should be 1');
  }

  if (maxElementPerLogicalRow[0] !== logicalRowCount || normalizationPerLogicalRow[0] !== logicalRowCount) {
    throw new Error('Shape of the intermediate results should be equal to logical row count');
  }

  const shaderSource = `
      float process(int[${rank}] indices) {

      // get offset of current logical tensor index from the 2-D texture coordinates (TexCoords)
      int offset = coordsToOffset(TexCoords, ${textureWidth}, ${textureHeight});

      //determine the logical row for this index
      int logical_row_index[1];
      logical_row_index[0] = offset / ${featureCount};

      float norm_factor = _Norm(logical_row_index);

      // avoid possible division by 0
      // if norm_facor is 0, all elements are zero
      // if so, return 0
      if(norm_factor == 0.0)
        return 0.0;

      return exp(_A(indices) - _Max(logical_row_index)) / norm_factor;
    }`;
  return {
    ...softmaxProgramMetadata,
    output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
330

331
// Validates the operator inputs: exactly one tensor of a float type.
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 1) {
    throw new Error('Softmax requires 1 input.');
  }

  const inputType = inputs[0].type;
  if (inputType !== 'float32' && inputType !== 'float64') {
    throw new Error('Invalid input type');
  }
};
340

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.