// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// TODO: this is the same naive implementation we use for reduce that has
// performance limitations when the reduced axis is long. Need to add
// an optimized codepath for this.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';

import {
  getMaxComponents,
  inputVariable,
  outputVariable,
  ShaderHelper,
  sumVector,
  tensorTypeToWsglStorageType,
} from './common';
const validateInputs = (inputs: readonly TensorView[]): void => {
24
if (!inputs || inputs.length !== 1) {
25
throw new Error('Softmax op requires 1 input.');
29
export interface SoftmaxAttributes extends AttributeWithCacheKey {
30
readonly axis: number;
33
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
34
const shape = input.dims;
35
const outputSize = ShapeUtil.size(shape);
37
let axis = attributes.axis;
39
axis = shape.length + axis;
41
if (axis < shape.length - 1) {
42
throw new Error('softmax only supports last axis for now.');
45
const cols = shape[axis];
46
const rows = outputSize / cols;
47
const components = getMaxComponents(cols);
48
const packedCols = cols / components;
50
const maxVector = (name: string, components: number) => {
51
if (components === 4) {
52
return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`;
53
} else if (components === 2) {
54
return `max(${name}.x, ${name}.y)`;
55
} else if (components === 3) {
56
return `max(max(${name}.x, ${name}.y), ${name}.z)`;
61
const x = inputVariable('x', input.dataType, input.dims, components);
62
const output = outputVariable('result', input.dataType, input.dims, components);
63
const valueType = x.type.value;
66
tensorTypeToWsglStorageType(input.dataType) === 'f32'
67
? `var threadMax = ${valueType}(-3.402823e+38f);`
68
: `var threadMax = ${valueType}(-65504.0h);`;
69
const getShaderSource = (shaderHelper: ShaderHelper) => `
70
var<workgroup> rowMaxShared : ${valueType};
71
var<workgroup> rowSumShared : ${valueType};
72
var<workgroup> threadShared : array<${valueType}, ${WG}>;
74
fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
75
let index = row * row_stride + col;
79
fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
80
let index = row * row_stride + col;
81
result[index] = value;
83
${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)}
84
${shaderHelper.mainStart()}
85
let gindex = i32(global_idx);
86
let lindex = i32(local_idx);
88
let row = gindex / wg;
89
let cols = uniforms.packedCols;
90
let row_stride : i32 = uniforms.packedCols;
94
for (var col = lindex; col < cols; col += wg) {
95
let value = getValue(row, col, row_stride);
96
threadMax = max(threadMax, value);
99
threadShared[lindex] = threadMax;
103
var reduceSize = min(cols, wg);
104
for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
105
reduceSize = currSize + (reduceSize & 1);
106
if (lindex < currSize) {
107
threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);
112
rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
117
var threadSum = ${valueType}(0.0);
118
for (var col = lindex; col < cols; col += wg) {
119
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
122
threadShared[lindex] = threadSum;
125
for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {
126
if (lindex < currSize) {
127
threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];
132
rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
136
// calculate final value for each element in the row
137
for (var col = lindex; col < cols; col += wg) {
138
let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
139
setValue(row, col, row_stride, value);
144
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
146
outputs: [{ dims: shape, dataType: input.dataType }],
147
dispatchGroup: { x: rows },
148
programUniforms: [{ type: DataType.int32, data: packedCols }],
154
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
155
validateInputs(context.inputs);
156
context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes));
159
/** Builds SoftmaxAttributes (with cache key) from the raw attribute record. */
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({ axis });
};