onnxruntime

Форк
0
/
reduce-shared.ts 
285 строк · 9.2 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types';
8

9
import { inputVariable, outputVariable, ShaderHelper } from './common';
10
import { createReduceAttributesFromInputs, ReduceAttributes } from './reduce';
11
import { createTransposeProgramInfo } from './transpose';
12

13
/**
 * WGSL update expressions for the per-thread (first) phase of the shared-memory reduction.
 *
 * Each entry combines the running accumulator `bestValue` with `candidate`, the next input
 * element already converted to `f32`, and yields the new accumulator value.
 */
const reduceOps: Record<string, string> = {
  max: 'select(bestValue, candidate, candidate > bestValue)',
  min: 'select(bestValue, candidate, candidate < bestValue)',
  mean: 'bestValue + candidate',
  sum: 'bestValue + candidate',
  prod: 'bestValue * candidate',
  // The next four entries transform each raw element before accumulating it.
  sumSquare: 'bestValue + candidate * candidate',
  logSumExp: 'bestValue + exp(candidate)',
  l1: 'bestValue + abs(candidate)',
  l2: 'bestValue + candidate * candidate',
  logSum: 'bestValue + candidate',
};
25

26
/**
 * WGSL expressions that merge two partial results in the workgroup-memory (second) phase.
 *
 * Here `candidate` is another thread's partial accumulator read from `aBestValues`. Element-wise
 * transforms (square, exp, abs) were already applied in the first phase, so every additive
 * reduction merges its partials with plain addition.
 */
const reduceSharedOps: Record<string, string> = {
  max: 'select(bestValue, candidate, candidate > bestValue)',
  min: 'select(bestValue, candidate, candidate < bestValue)',
  mean: 'bestValue + candidate',
  sum: 'bestValue + candidate',
  prod: 'bestValue * candidate',
  sumSquare: 'bestValue + candidate', // partials are already sums of squares
  logSumExp: 'bestValue + candidate', // partials are already sums of exp()
  l1: 'bestValue + candidate', // partials are already sums of abs()
  l2: 'bestValue + candidate', // partials are already sums of squares
  logSum: 'bestValue + candidate',
};
38

39
/**
 * WGSL expressions used to seed each thread's accumulator (wrapped in `f32(...)` by the shader).
 *
 * `max`/`min` start from the first element of the output element's run: `_A[offset]`, where `_A`
 * is the name given to the WGSL input variable and `offset` is the run's start index in the shader.
 * Additive reductions start from 0 and `prod` from 1.
 */
const reduceInitValues: Record<string, string> = {
  max: '_A[offset]',
  min: '_A[offset]',
  mean: '0',
  sum: '0',
  prod: '1',
  sumSquare: '0',
  logSumExp: '0',
  l1: '0',
  l2: '0',
  logSum: '0',
};
51

52
/**
 * WGSL expressions applied by thread 0 to the fully reduced accumulator before it is stored.
 *
 * `mean` is intentionally absent: the shader special-cases it and divides the sum by the number
 * of reduced elements instead of looking it up here.
 */
const reduceOutputValues: Record<string, string> = {
  max: 'bestValue',
  min: 'bestValue',
  sum: 'bestValue',
  prod: 'bestValue',
  sumSquare: 'bestValue',
  logSumExp: 'log(bestValue)',
  l1: 'bestValue',
  l2: 'sqrt(bestValue)',
  logSum: 'log(bestValue)',
};
63

64
/**
 * Lists the last `numInnerAxes` axis indices of a tensor of the given rank, in ascending order
 * (e.g. `numInnerAxes = 2`, `rank = 4` gives `[2, 3]`).
 */
const getInnerMostAxes = (numInnerAxes: number, rank: number): number[] => {
  const innerAxes: number[] = [];
  let axis = rank - numInnerAxes;
  while (axis < rank) {
    innerAxes.push(axis);
    axis += 1;
  }
  return innerAxes;
};
71

72
/**
 * Splits `shape` into the dimensions that survive the reduction and the dimensions being reduced.
 *
 * @param shape - full input shape.
 * @param axes - axis indices to reduce.
 * @returns `[outputShape, reduceShape]`. `outputShape` keeps the non-reduced dims in their original
 *   order. `reduceShape` lists the reduced dims in the order given by `axes`.
 */
const computeOutAndReduceShapes = (shape: readonly number[], axes: readonly number[]): [number[], number[]] => {
  const keptDims = shape.filter((_size, dim) => !axes.includes(dim));
  const reducedDims = axes.map((axis) => shape[axis]);
  return [keptDims, reducedDims];
};
83

84
/**
 * Re-inserts size-1 dimensions at the positions listed in `axes` (the `keepDims` behavior).
 *
 * The result has rank `shape.length + axes.length`. Positions named in `axes` become 1, and every
 * other position takes the next entry of `shape` in order.
 */
const expandShapeToKeepDim = (shape: number[], axes: number[]): number[] => {
  const expandedRank = shape.length + axes.length;
  let nextKept = 0;
  return Array.from({ length: expandedRank }, (_unused, dim) => (axes.includes(dim) ? 1 : shape[nextKept++]));
};
96
};
97

98
/**
 * Reports whether `axes` is exactly the trailing run `[rank - axes.length, ..., rank - 1]`,
 * in ascending order. An empty `axes` list trivially qualifies.
 */
const areAxesInnerMostDims = (axes: number[], rank: number): boolean => {
  const firstInner = rank - axes.length;
  return axes.every((axis, position) => axis === firstInner + position);
};
106

107
/**
 * Builds a transpose permutation that moves the reduced `axes` to the innermost positions.
 *
 * The non-reduced axes keep their relative order and come first, followed by `axes` in the order
 * given. Returns an empty array when `axes` are already the innermost dims, meaning no transpose is needed.
 */
const getAxesPermutation = (axes: number[], rank: number): number[] => {
  if (areAxesInnerMostDims(axes, rank)) {
    return [];
  }
  const keptDims = Array.from({ length: rank }, (_unused, dim) => dim).filter((dim) => !axes.includes(dim));
  return keptDims.concat(axes);
};
119

120
/**
 * Builds a WebGPU program that reduces `inputs[0]` over its innermost `reduceShape` dimensions,
 * using one 32-thread workgroup per output element.
 *
 * The reduced dimensions must already be innermost. Each output element then owns a contiguous run
 * of `ShapeUtil.size(reduceShape)` input elements starting at `outputIndex * reduceSize`.
 *
 * The shader works in two phases:
 *  1. Each thread folds a strided subset of the run into a private `f32` accumulator (`reduceOps`).
 *  2. The 32 partials are merged pairwise through workgroup memory (`reduceSharedOps`). Thread 0
 *     then applies the final transform (`reduceOutputValues`, or division by the count for `mean`)
 *     and stores the result.
 *
 * @param name - program name.
 * @param shaderCache - shader cache info, forwarded unchanged to the returned program.
 * @param inputs - `inputs[0]` is the tensor to reduce; further entries are not read.
 * @param reduceType - one of `max`, `min`, `mean`, `sum`, `prod`, `sumSquare`, `logSumExp`, `l1`, `l2`, `logSum`.
 * @param outputDataType - data type of the output tensor.
 * @param outputShape - shape of the output tensor.
 * @param reduceShape - sizes of the reduced (innermost) dimensions.
 * @returns the program description to hand to `ComputeContext.compute`.
 * @throws Error if `reduceType` is not a supported reduction.
 */
export const createReduceSharedProgramInfo = (
  name: string,
  shaderCache: ProgramShaderCacheInfo,
  inputs: readonly TensorView[],
  reduceType: string,
  outputDataType: DataType,
  outputShape: number[],
  reduceShape: number[],
): ProgramInfo => {
  // Fail fast instead of interpolating `undefined` (or an Object.prototype member) into the WGSL source.
  if (!Object.prototype.hasOwnProperty.call(reduceOps, reduceType)) {
    throw new Error(`Unsupported reduce type: ${reduceType}`);
  }

  const inputShape = inputs[0].dims;

  const outputSize = ShapeUtil.size(outputShape);
  const reduceSize = ShapeUtil.size(reduceShape);

  const input = inputVariable('_A', inputs[0].dataType, inputShape);
  const output = outputVariable('output', outputDataType, outputShape);

  const workgroupSize = 32;

  // max/min are seeded with the first element of the run. Read it through the input variable's
  // accessor, exactly as the main loop does, so that any storage-type conversion is applied. The raw
  // `_A[offset]` text in `reduceInitValues` would bypass it.
  const initValue =
    reduceType === 'max' || reduceType === 'min' ? input.getByOffset('offset') : reduceInitValues[reduceType];

  // mean divides the accumulated sum by the element count; every other op uses its table entry.
  const outputValue =
    reduceType === 'mean'
      ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))`
      : `${output.type.storage}(${reduceOutputValues[reduceType]})`;

  const sharedMemorySnippet = `
          var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
       `;

  const getShaderSource = (shaderHelper: ShaderHelper) => `
        ${shaderHelper.registerUniform('reduceSize', 'u32').declareVariables(input, output)}
        ${sharedMemorySnippet}
        fn DIV_CEIL(a : u32, b : u32) -> u32 {
          return ((a - 1u) / b + 1u);
         }
         ${shaderHelper.mainStart(workgroupSize)}

          let outputIndex = global_idx / ${workgroupSize};
          let offset = outputIndex * uniforms.reduceSize;

          var bestValue = f32(${initValue});
          let Length = uniforms.reduceSize;
          for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
           let candidate = f32(${input.getByOffset('offset + k')});
           bestValue = ${reduceOps[reduceType]};
          }
          aBestValues[local_idx] = bestValue;
          workgroupBarrier();

         var reduceSize = min(Length, ${workgroupSize}u);
         for (var currentSize = reduceSize / 2u; reduceSize > 1u;
             currentSize = reduceSize / 2u) {
           let interval = DIV_CEIL(reduceSize, 2u);
           if (local_idx < currentSize) {
            let candidate = aBestValues[local_idx + interval];
            bestValue = ${reduceSharedOps[reduceType]};
            aBestValues[local_idx] = bestValue;
           }
           reduceSize = interval;
           workgroupBarrier();
         }

         if (local_idx == 0u) {
          ${output.setByOffset('outputIndex', outputValue)};
         }
        }`;

  // One work group is responsible for only one element of output.
  return {
    name,
    shaderCache,
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: outputSize },
      programUniforms: [{ type: DataType.uint32, data: reduceSize }],
    }),
  };
};
200

201
/**
 * Shared driver behind every exported `reduce*Shared` kernel.
 *
 * It resolves the reduction axes, taking them from the attributes or, when extra inputs are present,
 * via `createReduceAttributesFromInputs`. If the reduced axes are not already innermost, it
 * transposes them there. It then runs the program built by `createReduceSharedProgramInfo`.
 *
 * @param context - compute context; `context.inputs[0]` is the tensor to reduce.
 * @param name - program name forwarded to `createReduceSharedProgramInfo`.
 * @param attributes - reduce attributes, used as-is when there is exactly one input.
 * @param reduceType - reduction to perform.
 */
const reduceCommon = (
  context: ComputeContext,
  name: string,
  attributes: ReduceAttributes,
  reduceType: 'sum' | 'sumSquare' | 'prod' | 'min' | 'max' | 'mean' | 'logSumExp' | 'l1' | 'l2' | 'logSum',
): void => {
  const attrs: ReduceAttributes =
    context.inputs.length === 1 ? attributes : createReduceAttributesFromInputs(context.inputs, attributes);
  const source = context.inputs[0];
  const sourceRank = source.dims.length;

  // An empty axis list means "reduce everything" unless noopWithEmptyAxes is set.
  const reduceAll = attrs.axes.length === 0 && !attrs.noopWithEmptyAxes;
  const requestedAxes = reduceAll ? source.dims.map((_dim, i) => i) : attrs.axes;
  const originalAxes = ShapeUtil.normalizeAxes(requestedAxes, sourceRank);

  // The shared-memory program reduces contiguous innermost runs, so move the reduced axes last if needed.
  const perm = getAxesPermutation(originalAxes, sourceRank);
  const needsTranspose = perm.length > 0;
  const reduceInput = needsTranspose
    ? context.compute(createTransposeProgramInfo(source, perm), { inputs: [0], outputs: [-1] })[0]
    : source;
  const reduceAxes = needsTranspose ? getInnerMostAxes(originalAxes.length, reduceInput.dims.length) : originalAxes;

  const [outputShape, reduceShape] = computeOutAndReduceShapes(reduceInput.dims, reduceAxes);
  // keepDims restores size-1 dims at the ORIGINAL (pre-transpose) axis positions.
  const finalOutputShape = attrs.keepDims ? expandShapeToKeepDim(outputShape, originalAxes) : outputShape;

  context.compute(
    createReduceSharedProgramInfo(
      name,
      { hint: attrs.cacheKey, inputDependencies: ['type'] },
      [reduceInput],
      reduceType,
      source.dataType,
      finalOutputShape,
      reduceShape,
    ),
    { inputs: [reduceInput] },
  );
};
246

247
/** ReduceMean via the shared-memory reduction kernel. */
export const reduceMeanShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceMeanShared', attributes, 'mean');

/** ReduceL1 (sum of absolute values) via the shared-memory reduction kernel. */
export const reduceL1Shared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceL1Shared', attributes, 'l1');

/** ReduceL2 (square root of the sum of squares) via the shared-memory reduction kernel. */
export const reduceL2Shared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceL2Shared', attributes, 'l2');

/** ReduceLogSumExp (log of the sum of exponentials) via the shared-memory reduction kernel. */
export const reduceLogSumExpShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceLogSumExpShared', attributes, 'logSumExp');

/** ReduceMax via the shared-memory reduction kernel. */
export const reduceMaxShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceMaxShared', attributes, 'max');

/** ReduceMin via the shared-memory reduction kernel. */
export const reduceMinShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceMinShared', attributes, 'min');

/** ReduceProd via the shared-memory reduction kernel. */
export const reduceProdShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceProdShared', attributes, 'prod');

/** ReduceSum via the shared-memory reduction kernel. */
export const reduceSumShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceSumShared', attributes, 'sum');

/** ReduceSumSquare via the shared-memory reduction kernel. */
export const reduceSumSquareShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceSumSquareShared', attributes, 'sumSquare');

/** ReduceLogSum (log of the sum) via the shared-memory reduction kernel. */
export const reduceLogSumShared = (context: ComputeContext, attributes: ReduceAttributes): void =>
  reduceCommon(context, 'ReduceLogSumShared', attributes, 'logSum');
286

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.