1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { getGlsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, TextureType } from '../types';
13
import { transpose, TransposeAttributes } from './transpose';
15
export interface SoftmaxAttributes extends AttributeWithCacheKey {
16
readonly axis: number;
19
const softmaxComputeMaxProgramMetadata = {
20
name: 'SoftmaxComputeMax',
22
inputTypes: [TextureType.unpacked],
25
const softmaxComputeScaleProgramMetadata = {
26
name: 'SoftmaxComputeScale',
27
inputNames: ['A', 'Max'],
28
inputTypes: [TextureType.unpacked, TextureType.unpacked],
31
const softmaxProgramMetadata = {
33
inputNames: ['A', 'Max', 'Norm'],
34
inputTypes: [TextureType.unpacked, TextureType.unpacked, TextureType.unpacked],
37
export const softmax: OperatorImplementation<SoftmaxAttributes> = (
38
inferenceHandler: WebGLInferenceHandler,
40
attributes: SoftmaxAttributes,
42
validateInputs(inputs);
44
const inputShape = inputs[0].dims.slice();
45
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
46
const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, axis);
47
const featureCount = ShapeUtil.sizeFromDimension(inputShape, axis);
49
const output = computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, featureCount);
53
export const parseSoftmaxAttributes: OperatorInitialization<SoftmaxAttributes> = (
55
): SoftmaxAttributes => createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', 1) });
57
export const parseSoftmaxAttributesV13: OperatorInitialization<SoftmaxAttributes> = (
59
): SoftmaxAttributes => createAttributeWithCacheKey({ axis: node.attributes.getInt('axis', -1) });
61
// The "semantic" meaning of axis has changed in opset-13.
62
// Please compare: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softmax
63
// with https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Softmax-11 for detailed explanations
64
// To account for the opset-13 behavior, our plan will be to transpose the "axis" dim to the innermost dim
65
// and perform softmax and then reverse the transpose. We can skip the transposing aspect if the axis is already
67
export const softmaxV13: OperatorImplementation<SoftmaxAttributes> = (
68
inferenceHandler: WebGLInferenceHandler,
70
attributes: SoftmaxAttributes,
72
validateInputs(inputs);
74
const inputShape = inputs[0].dims.slice();
75
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
76
const rank = inputShape.length;
78
const isTransposeRequired = axis !== rank - 1 ? true : false;
79
const transposedInputShape: number[] = [];
80
let perm: number[] = [];
81
let transposedInputs: Tensor[] = [];
82
let transposeAttribute: TransposeAttributes;
84
if (isTransposeRequired) {
85
perm = Array.from({ length: rank }).map((_, i) => i);
87
// swap the innermost dim with the dim corresponding to axis
88
perm[axis] = rank - 1;
89
perm[rank - 1] = axis;
91
perm.map((p) => transposedInputShape.push(inputShape[p]));
93
transposeAttribute = createAttributeWithCacheKey({ perm });
94
transposedInputs = transpose(inferenceHandler, inputs, transposeAttribute);
97
const logicalRowCount = isTransposeRequired
98
? ShapeUtil.sizeToDimension(transposedInputShape, rank - 1)
99
: ShapeUtil.sizeToDimension(inputShape, rank - 1);
100
const featureCount = isTransposeRequired
101
? ShapeUtil.sizeFromDimension(transposedInputShape, rank - 1)
102
: ShapeUtil.sizeFromDimension(inputShape, rank - 1);
104
const output = computeSoftmax(
106
isTransposeRequired ? transposedInputs : inputs,
112
if (isTransposeRequired) {
113
const reversedOutput = transpose(inferenceHandler, output, transposeAttribute!);
114
return reversedOutput;
120
const computeSoftmax = (
121
inferenceHandler: WebGLInferenceHandler,
123
attributes: SoftmaxAttributes,
124
logicalRowCount: number,
125
featureCount: number,
127
const computeMaxProgramInfo = createComputeMaxProgramInfo(
134
const max = inferenceHandler.run(
135
{ ...softmaxComputeMaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeMaxProgramInfo },
139
const computeScaleProgramInfo = createComputScaleProgramInfo(
144
computeMaxProgramInfo.output.dims,
147
const scale = inferenceHandler.run(
148
{ ...softmaxComputeScaleProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeScaleProgramInfo },
152
const softMaxProgramInfo = createSoftMaxProgramInfo(
157
computeMaxProgramInfo.output.dims,
158
computeScaleProgramInfo.output.dims,
160
const output = inferenceHandler.run(
161
{ ...softmaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => softMaxProgramInfo },
162
[inputs[0], max, scale],
168
* Create a texture that contains the maximum value of each of the 'N' rows
170
const createComputeMaxProgramInfo = (
171
inferenceHandler: WebGLInferenceHandler,
173
logicalRowCount: number,
174
featureCount: number,
175
outputShape: number[],
177
const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
179
TextureType.unpacked,
181
const rank = outputShape.length;
183
if (logicalRowCount < 1 || featureCount < 1) {
184
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
187
if (outputShape.length !== 1) {
188
throw new Error('Dimensionality of the output should be 1');
191
if (outputShape[0] !== logicalRowCount) {
192
throw new Error('Shape of the output should be equal to logical row count');
195
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
196
const shaderSource = `
197
float process(int[${rank}] indices) {
198
int logical_row_start_offset = indices[0] * ${featureCount};
200
float max = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset, ${textureWidth},
201
${textureHeight} )));
202
for(int i=1; i<${featureCount}; ++i)
204
float current = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
205
${textureWidth}, ${textureHeight})));
213
...softmaxComputeMaxProgramMetadata,
214
output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
220
* Create a texture that contains the normalization factor for each of the 'N' rows
222
const createComputScaleProgramInfo = (
223
inferenceHandler: WebGLInferenceHandler,
225
logicalRowCount: number,
226
featureCount: number,
227
maxElementPerLogicalRow: readonly number[],
228
outputShape: number[],
230
const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
232
TextureType.unpacked,
234
const rank = outputShape.length;
236
if (logicalRowCount < 1 || featureCount < 1) {
237
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
240
if (outputShape.length !== 1) {
241
throw new Error('Dimensionality of the output should be 1');
244
if (outputShape[0] !== logicalRowCount) {
245
throw new Error('Shape of the output should be equal to logical row count');
248
if (maxElementPerLogicalRow.length !== 1) {
249
throw new Error('Dimensionality of the intermediate results should be 1');
252
if (maxElementPerLogicalRow[0] !== logicalRowCount) {
253
throw new Error('Shape of the intermediate results should be equal to logical row count');
256
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
257
const shaderSource = `
258
float process(int[${rank}] indices) {
259
int logical_row_start_offset = indices[0] * ${featureCount};
261
float norm_factor = 0.0;
262
float max = _Max(indices);
263
for(int i=0; i<${featureCount}; ++i)
265
norm_factor += exp(getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
266
${textureWidth}, ${textureHeight}))) - max);
272
...softmaxComputeScaleProgramMetadata,
273
output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
278
const createSoftMaxProgramInfo = (
279
inferenceHandler: WebGLInferenceHandler,
281
logicalRowCount: number,
282
featureCount: number,
283
maxElementPerLogicalRow: readonly number[],
284
normalizationPerLogicalRow: readonly number[],
286
const [textureWidth, textureHeight] = inferenceHandler.calculateTextureWidthAndHeight(
288
TextureType.unpacked,
290
const rank = input.dims.length;
292
if (logicalRowCount < 1 || featureCount < 1) {
293
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
296
if (maxElementPerLogicalRow.length !== 1 || normalizationPerLogicalRow.length !== 1) {
297
throw new Error('Dimensionality of the intermediate results should be 1');
300
if (maxElementPerLogicalRow[0] !== logicalRowCount || normalizationPerLogicalRow[0] !== logicalRowCount) {
301
throw new Error('Shape of the intermediate results should be equal to logical row count');
304
const shaderSource = `
305
float process(int[${rank}] indices) {
307
// get offset of current logical tensor index from the 2-D texture coordinates (TexCoords)
308
int offset = coordsToOffset(TexCoords, ${textureWidth}, ${textureHeight});
310
//determine the logical row for this index
311
int logical_row_index[1];
312
logical_row_index[0] = offset / ${featureCount};
314
float norm_factor = _Norm(logical_row_index);
316
// avoid possible division by 0
317
// if norm_facor is 0, all elements are zero
319
if(norm_factor == 0.0)
322
return exp(_A(indices) - _Max(logical_row_index)) / norm_factor;
325
...softmaxProgramMetadata,
326
output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
331
const validateInputs = (inputs: Tensor[]): void => {
332
if (!inputs || inputs.length !== 1) {
333
throw new Error('Softmax requires 1 input.');
336
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
337
throw new Error('Invalid input type');