onnxruntime

Форк
0
160 строк · 5.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
// TODO: this is the same naive implementation we use for reduce that has
5
// performance limitations when the reduced axis is long. Need to add
6
// a optimized codepath for this.
7

8
import { DataType } from '../../../wasm-common';
9
import { TensorView } from '../../tensor-view';
10
import { ShapeUtil } from '../../util';
11
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
12
import { ComputeContext, ProgramInfo } from '../types';
13

14
import {
15
  getMaxComponents,
16
  inputVariable,
17
  outputVariable,
18
  ShaderHelper,
19
  sumVector,
20
  tensorTypeToWsglStorageType,
21
} from './common';
22

23
/**
 * Ensures exactly one input tensor was supplied to the Softmax op.
 *
 * @param inputs the operator's input tensors.
 * @throws Error when the input list is missing or its length is not 1.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (inputs?.length !== 1) {
    throw new Error('Softmax op requires 1 input.');
  }
};
28

29
/**
 * Attributes for the Softmax operator.
 */
export interface SoftmaxAttributes extends AttributeWithCacheKey {
  // Axis along which softmax is computed. May be negative (counted from the
  // end); it is normalized in createSoftmaxProgramInfo, which currently
  // accepts only the last axis.
  readonly axis: number;
}
32

33
/**
 * Creates the ProgramInfo for Softmax computed over the last axis.
 *
 * The generated WGSL shader treats the input as a (rows x cols) matrix and
 * assigns one workgroup per row (dispatchGroup.x === rows): each of the WG
 * threads strides across the row to accumulate a thread-local max, the
 * maxima are tree-reduced through workgroup-shared memory, the same
 * stride/reduce pattern computes the row sum of exp(x - rowMax), and finally
 * every element is written as exp(x - rowMax) / rowSum — the numerically
 * stable softmax formulation.
 *
 * @param input the single input tensor.
 * @param attributes softmax attributes; `axis` must resolve to the last
 *   dimension after negative-axis normalization, otherwise an Error is thrown.
 * @returns the ProgramInfo (shader source, dispatch size, uniforms).
 */
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
  const shape = input.dims;
  const outputSize = ShapeUtil.size(shape);
  // Threads per workgroup; also the length of the shared scratch array.
  // NOTE(review): shaderHelper.mainStart() presumably declares a matching
  // workgroup size of 64 by default — confirm against common.ts.
  const WG = 64;
  let axis = attributes.axis;
  if (axis < 0) {
    // Normalize a negative axis (counted from the end).
    axis = shape.length + axis;
  }
  if (axis < shape.length - 1) {
    throw new Error('softmax only supports last axis for now.');
  }

  // View the tensor as (rows x cols), reducing along cols (the last axis).
  const cols = shape[axis];
  const rows = outputSize / cols;
  // Vectorize loads/stores: components is 1, 2, 3 or 4 depending on cols.
  const components = getMaxComponents(cols);
  const packedCols = cols / components;

  // Builds a WGSL expression reducing a vecN (or scalar, components === 1)
  // value to the max of its lanes.
  const maxVector = (name: string, components: number) => {
    if (components === 4) {
      return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`;
    } else if (components === 2) {
      return `max(${name}.x, ${name}.y)`;
    } else if (components === 3) {
      return `max(max(${name}.x, ${name}.y), ${name}.z)`;
    }

    return name;
  };
  const x = inputVariable('x', input.dataType, input.dims, components);
  const output = outputVariable('result', input.dataType, input.dims, components);
  const valueType = x.type.value;
  // Seed for the running max: the lowest finite value representable by the
  // storage type (f32 vs f16 literal forms per section 6.2.4 of the WGSL spec).
  const threadMaxDecl =
    tensorTypeToWsglStorageType(input.dataType) === 'f32'
      ? `var threadMax = ${valueType}(-3.402823e+38f);`
      : `var threadMax = ${valueType}(-65504.0h);`;
  // The template below IS the shader source; its embedded comments and
  // formatting are part of the emitted WGSL string.
  const getShaderSource = (shaderHelper: ShaderHelper) => `
      var<workgroup> rowMaxShared : ${valueType};
      var<workgroup> rowSumShared : ${valueType};
      var<workgroup> threadShared : array<${valueType}, ${WG}>;

      fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
        let index = row * row_stride + col;
        return x[index];
      }

      fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
        let index = row * row_stride + col;
        result[index] = value;
      }
      ${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)}
      ${shaderHelper.mainStart()}
        let gindex = i32(global_idx);
        let lindex = i32(local_idx);
        const wg = ${WG};
        let row = gindex / wg;
        let cols = uniforms.packedCols;
        let row_stride : i32 = uniforms.packedCols;

        // find the rows max
        ${threadMaxDecl}
        for (var col = lindex; col < cols; col += wg) {
          let value = getValue(row, col, row_stride);
          threadMax = max(threadMax, value);
        }
        if (lindex < cols) {
          threadShared[lindex] = threadMax;
        }
        workgroupBarrier();

        var reduceSize = min(cols, wg);
        for (var currSize = reduceSize >> 1;  currSize > 0; currSize = reduceSize >> 1) {
          reduceSize = currSize + (reduceSize & 1);
          if (lindex < currSize) {
            threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);
          }
          workgroupBarrier();
        }
        if (lindex == 0) {
          rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
        }
        workgroupBarrier();

        // find the rows sum
        var threadSum = ${valueType}(0.0);
        for (var col = lindex; col < cols; col += wg) {
          let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
          threadSum += subExp;
        }
        threadShared[lindex] = threadSum;
        workgroupBarrier();

        for (var currSize = wg >> 1;  currSize > 0; currSize = currSize >> 1) {
          if (lindex < currSize) {
            threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];
          }
          workgroupBarrier();
        }
        if (lindex == 0) {
          rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
        }
        workgroupBarrier();

        // calculate final value for each element in the row
        for (var col = lindex; col < cols; col += wg) {
          let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
          setValue(row, col, row_stride, value);
        }
      }`;
  return {
    name: 'Softmax',
    // `components` changes the generated WGSL, so it must be in the cache
    // hint; the data type is covered by the 'type' input dependency.
    shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
    getRunData: () => ({
      outputs: [{ dims: shape, dataType: input.dataType }],
      // One workgroup per row.
      // NOTE(review): assumes `rows` fits within the device's maximum
      // dispatch dimension — confirm for very large batch sizes.
      dispatchGroup: { x: rows },
      // packedCols (cols / components) is read by the shader as a uniform.
      programUniforms: [{ type: DataType.int32, data: packedCols }],
    }),
    getShaderSource,
  };
};
153

154
/**
 * Computes Softmax over the single input tensor and submits the program
 * through the provided compute context.
 */
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
  const { inputs } = context;
  validateInputs(inputs);
  const programInfo = createSoftmaxProgramInfo(inputs[0], attributes);
  context.compute(programInfo);
};
158

159
/**
 * Converts raw operator attributes into SoftmaxAttributes carrying a cache key.
 */
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({ axis });
};
161

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.