onnxruntime

Форк
0
412 строк · 13.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramShaderCacheInfo } from '../types';
9

10
import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
11
import {
12
  reduceL1Shared,
13
  reduceL2Shared,
14
  reduceLogSumExpShared,
15
  reduceLogSumShared,
16
  reduceMaxShared,
17
  reduceMeanShared,
18
  reduceMinShared,
19
  reduceProdShared,
20
  reduceSumShared,
21
  reduceSumSquareShared,
22
} from './reduce-shared';
23

24
/**
 * Checks the inputs of a naive Reduce kernel: a data tensor, optionally followed by a 1-D axes tensor.
 *
 * @throws Error when there are not exactly 1 or 2 inputs, or when the second input is not 1-D.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const inputCount = inputs ? inputs.length : 0;
  if (inputCount < 1 || inputCount > 2) {
    throw new Error('Reduce op requires 1 or 2 inputs.');
  }

  const axesTensor = inputCount === 2 ? inputs[1] : undefined;
  if (axesTensor !== undefined && axesTensor.dims.length !== 1) {
    throw new Error('Invalid axes input dims.');
  }
};
33

34
/**
 * Attributes shared by the Reduce* kernels; `cacheKey` (from AttributeWithCacheKey) is used as the shader cache hint.
 */
export interface ReduceAttributes extends AttributeWithCacheKey {
  /** Axes to reduce over; may be empty, in which case `noopWithEmptyAxes` decides the behavior. */
  axes: number[];
  /** When true, reduced axes stay in the output shape with size 1; otherwise they are removed. */
  keepDims: boolean;
  /** With empty `axes`: true makes the op an identity copy, false reduces over every axis. */
  noopWithEmptyAxes: boolean;
}
39

40
/**
 * Produces the WGSL fragments of one reduction, given the input/output variables and the normalized axes.
 *
 * The fragments are spliced into the naive reduce shader in this order: `setup` (after the index copy), `init`,
 * `loopBody` (wrapped in one loop per reduced axis), `finalize` (after the loops). When extra `store` entries are
 * returned they are emitted instead of the default `value` store to the output.
 */
export type ReduceOp = (
  input: IndicesHelper,
  output: IndicesHelper,
  axes: readonly number[],
) => [setup: string, init: string, loopBody: string, finalize: string, ...store: string[]];
45

46
/** Identity "reduction" used when no axes are given and `noopWithEmptyAxes` is set: each element is copied through. */
const noOp: ReduceOp = (input) => {
  const copyElement = `var value = ${input.getByIndices('input_indices')};`;
  return ['', '', copyElement, ''];
};
47
/**
 * Builds the ProgramInfo for the naive reduce shader. Each invocation handles one output element (`global_idx`),
 * walks every reduced axis of `inputs[0]` with nested loops, and combines the visited elements using the WGSL
 * snippets returned by `reduceOp`.
 *
 * @param name - program name.
 * @param shaderCache - cache info, passed through unchanged. NOTE(review): the generated WGSL inlines
 *   `inputs[0].dims` as loop bounds, so callers presumably need a cache key that distinguishes input dims.
 * @param inputs - only `inputs[0]` is read.
 * @param reduceOp - WGSL snippet generator for the reduction (see ReduceOp).
 * @param axesInput - axes to reduce; normalized with ShapeUtil.normalizeAxes against the input rank.
 * @param outputDataType - data type of the output tensor.
 * @param keepDims - keep reduced axes in the output as size-1 dims instead of dropping them.
 * @param noopWithEmptyAxes - with empty `axesInput`: true reduces nothing, false reduces every axis.
 */
export const createReduceProgramInfo = (
  name: string,
  shaderCache: ProgramShaderCacheInfo,
  inputs: readonly TensorView[],
  reduceOp: ReduceOp,
  axesInput: number[],
  outputDataType: DataType,
  keepDims = false,
  noopWithEmptyAxes = false,
): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const inputRank = inputShape.length;
  const axes = ShapeUtil.normalizeAxes(axesInput, inputRank);
  const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0;
  // An input axis is reduced when it is listed in `axes`, or when every axis is being reduced.
  const isReducedAxis = (axis: number): boolean => reduceOnAllAxes || axes.indexOf(axis) >= 0;

  // Kept axes carry their size over; reduced axes become 1 under keepDims and are dropped otherwise.
  const outputShape: number[] = [];
  for (let axis = 0; axis < inputRank; axis++) {
    if (!isReducedAxis(axis)) {
      outputShape.push(inputShape[axis]);
    } else if (keepDims) {
      outputShape.push(1);
    }
  }
  const outputRank = outputShape.length;
  const outputSize = ShapeUtil.size(outputShape);
  // Threads per workgroup used for the dispatch size; assumed to match shaderHelper.mainStart()'s default.
  const workgroupSize = 64;

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const input = inputVariable('_A', inputs[0].dataType, inputRank);
    const output = outputVariable('output', outputDataType, outputRank);
    const ops = reduceOp(input, output, axes);
    const needsLastIndex = ops[2].includes('last_index');

    // Statements copying each kept output index into the matching input index.
    const idxCopy: string[] = [];
    // Starts as the per-element body and gets wrapped in one loop per reduced axis, so the highest reduced axis
    // ends up as the outermost loop.
    let reduceOps = ops[2];
    // Position in `output_indices`: advances on kept axes, and on reduced axes only when keepDims retains them.
    let outputAxis = 0;
    for (let inputAxis = 0; inputAxis < inputRank; inputAxis++) {
      if (!isReducedAxis(inputAxis)) {
        const keptIndex = output.indicesGet('output_indices', outputAxis);
        idxCopy.push(`${input.indicesSet('input_indices', inputAxis, keptIndex)};`);
        outputAxis++;
        continue;
      }
      if (keepDims) {
        outputAxis++;
      }
      reduceOps = `for(var j${inputAxis}: u32 = 0; j${inputAxis} < ${inputShape[inputAxis]}; j${inputAxis}++) {
                  ${needsLastIndex ? `let last_index = j${inputAxis};` : ''}
                  ${input.indicesSet('input_indices', inputAxis, `j${inputAxis}`)}
                  ${reduceOps}
                }`;
    }

    return `

        ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}

        ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          var input_indices: ${input.type.indices};
          let output_indices = ${output.offsetToIndices('global_idx')};

          ${idxCopy.join('\n')}
          ${ops[0]}       // init ops for reduce max/min
          ${ops[1]}
          ${reduceOps}
          ${ops[3]}
          ${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')}
        }`;
  };

  return {
    name,
    shaderCache,
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType: outputDataType }],
      dispatchGroup: { x: Math.ceil(outputSize / workgroupSize) },
      programUniforms: [
        { type: DataType.uint32, data: outputSize },
        ...createTensorShapeVariables(inputShape, outputShape),
      ],
    }),
  };
};
130

131
/**
 * Rebuilds Reduce attributes for the two-input form of the op. The axes are read from `inputs[1]` as int64 values
 * (an empty axes tensor yields no axes), while `keepDims` and `noopWithEmptyAxes` come from `attributes`.
 */
export const createReduceAttributesFromInputs = (
  inputs: readonly TensorView[],
  attributes: ReduceAttributes,
): ReduceAttributes => {
  const axesTensor = inputs[1];
  const axes: number[] =
    axesTensor.dims[0] > 0 ? Array.from(axesTensor.getBigInt64Array(), (axis) => Number(axis)) : [];
  const { keepDims, noopWithEmptyAxes } = attributes;
  return createAttributeWithCacheKey({ axes, keepDims, noopWithEmptyAxes });
};
145

146
/**
 * Runs a naive reduce kernel over `context.inputs[0]`.
 *
 * With a second (axes) input, the axes are taken from that tensor; otherwise `attributes` is used as-is. If the
 * resulting axes are empty and `noopWithEmptyAxes` is set, the identity `noOp` replaces `reduceOp`.
 *
 * @param context - compute context providing the inputs and running the program.
 * @param name - program name.
 * @param attributes - node attributes.
 * @param reduceOp - WGSL snippet generator for the reduction.
 */
const runReduceProgram = (
  context: ComputeContext,
  name: string,
  attributes: ReduceAttributes,
  reduceOp: ReduceOp,
): void => {
  const inputs = context.inputs;
  const updatedAttributes: ReduceAttributes =
    inputs.length === 1 ? attributes : createReduceAttributesFromInputs(inputs, attributes);

  context.compute(
    createReduceProgramInfo(
      name,
      // The naive shader bakes the concrete input dims into its source (loop bounds in createReduceProgramInfo and
      // the divisor in ReduceMean), so it must depend on the full dims, not just the rank. NOTE(review): assumes
      // the program cache key is derived from `hint` + `inputDependencies`; with 'rank' a shader compiled for one
      // shape would be reused for another shape of the same rank.
      { hint: updatedAttributes.cacheKey, inputDependencies: ['dims'] },
      [inputs[0]],
      updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp,
      updatedAttributes.axes,
      inputs[0].dataType,
      updatedAttributes.keepDims,
      updatedAttributes.noopWithEmptyAxes,
    ),
    { inputs: [0] },
  );
};
170

171
/** Naive ReduceLogSum: `log` of the sum of the elements along the reduced axes. */
const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += ${input.getByIndices('input_indices')};`;
    return [init, '', accumulate, 'value = log(value);'];
  };
  runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp);
};
181

182
/** Naive ReduceL1: sum of absolute values along the reduced axes. */
const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += abs(${input.getByIndices('input_indices')});`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceL1', attributes, reduceOp);
};
192

193
/** Naive ReduceL2: square root of the sum of squares along the reduced axes. */
const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const valueType = output.type.value;
    const init = `var t = ${valueType}(0); var value = ${valueType}(0);`;
    const accumulate = `t = ${input.getByIndices('input_indices')}; value += (t * t);`;
    return [init, '', accumulate, 'value = sqrt(value);'];
  };
  runReduceProgram(context, 'ReduceL2', attributes, reduceOp);
};
203

204
/** Naive ReduceLogSumExp: `log` of the sum of `exp(x)` along the reduced axes. */
const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += exp(${input.getByIndices('input_indices')});`;
    return [init, '', accumulate, 'value = log(value);'];
  };
  runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp);
};
214

215
/**
 * Naive ReduceMax: seeds `value` with the element at index 0 of every reduced axis, then folds `max` over the
 * reduced axes.
 */
const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, _output, axes) => {
    // Point each reduced axis at index 0 so `value` starts from a real input element.
    const seedIndices: string[] = [];
    for (let axis = 0; axis < input.rank; axis++) {
      const reduced = axes.length === 0 || axes.indexOf(axis) >= 0;
      if (reduced) {
        seedIndices.push(input.indicesSet('input_indices', axis, 0));
      }
    }

    const element = input.getByIndices('input_indices');
    return [seedIndices.join('\n'), `var value = ${element};`, `value = max(value, ${element});`, ''];
  };
  runReduceProgram(context, 'ReduceMax', attributes, reduceOp);
};
234

235
/**
 * Naive ReduceMean: accumulates in f32, divides by the number of reduced elements, and converts back to the output
 * value type.
 */
const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output, axes) => {
    // Count of input elements folded into each output element.
    // TODO: this depends on the input dims. If we want to use uniform, this need to be updated.
    let size = 1.0;
    for (let axis = 0; axis < input.rank; axis++) {
      size *= axes.length === 0 || axes.includes(axis) ? context.inputs[0].dims[axis] : 1;
    }

    const accumulate = `sum += f32(${input.getByIndices('input_indices')});`;
    const finalize = `let value = ${output.type.value}(sum / ${size});`;
    return ['var sum = f32(0);', '', accumulate, finalize];
  };
  runReduceProgram(context, 'ReduceMean', attributes, reduceOp);
};
255

256
/**
 * Naive ReduceMin: seeds `value` with the element at index 0 of every reduced axis, then folds `min` over the
 * reduced axes.
 */
const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, _output, axes) => {
    const idxZero: string[] = [];
    for (let k = 0; k < input.rank; k++) {
      if (axes.indexOf(k) >= 0 || axes.length === 0) {
        // Go through indicesSet (as reduceMaxNaive does) instead of hand-writing `input_indices[k] = 0;`: a literal
        // subscript presumes the indices variable is an array. NOTE(review): assumed not to hold for rank-1 inputs;
        // confirm against IndicesHelper.indicesSet in ./common.
        idxZero.push(input.indicesSet('input_indices', k, 0)); // first element
      }
    }

    return [
      `${idxZero.join('\n')}`,
      `var value = ${input.getByIndices('input_indices')};`,
      `value = min(value, ${input.getByIndices('input_indices')});`,
      '',
    ];
  };
  runReduceProgram(context, 'ReduceMin', attributes, reduceOp);
};
275

276
/** Naive ReduceProd: product of the elements along the reduced axes, starting from 1. */
const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(1);`;
    const accumulate = `value *= ${input.getByIndices('input_indices')};`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceProd', attributes, reduceOp);
};
286

287
/** Naive ReduceSum: sum of the elements along the reduced axes. */
const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const init = `var value = ${output.type.storage}(0);`;
    const accumulate = `value += ${input.getByIndices('input_indices')};`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceSum', attributes, reduceOp);
};
297

298
/** Naive ReduceSumSquare: sum of squared elements along the reduced axes. */
const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttributes): void => {
  validateInputs(context.inputs);
  const reduceOp: ReduceOp = (input, output) => {
    const valueType = output.type.value;
    const init = `var t = ${valueType}(0); var value = ${valueType}(0);`;
    const accumulate = `t = ${input.getByIndices('input_indices')}; value += t * t;`;
    return [init, '', accumulate, ''];
  };
  runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp);
};
308

309
/**
 * Heuristic choosing between the naive reduce kernels in this file and the shared-memory kernels from
 * `reduce-shared`.
 *
 * @param shape - dims of the data input.
 * @param axes - axes to reduce; negative values count from the last dimension.
 * @param noopWithEmptyAxes - the op's noop-with-empty-axes attribute.
 * @returns true to use the naive kernel: for empty `axes` exactly when the op is a no-op, otherwise when each
 *   output element reduces few values and there are many output elements.
 */
const useNaiveReduceMethod = (
  shape: readonly number[],
  axes: readonly number[],
  noopWithEmptyAxes: boolean,
): boolean => {
  if (axes.length === 0) {
    return noopWithEmptyAxes;
  }

  const rank = shape.length;
  // Normalize negative axes so that, e.g., -1 and rank - 1 refer to the same dimension.
  const reducedAxes = new Set(axes.map((axis) => (axis < 0 ? axis + rank : axis)));

  let outputSize = 1;
  let reduceSize = 1;
  // Visit every input dimension (not just the first `axes.length` of them) so the two products cover the full shape.
  for (let dim = 0; dim < rank; dim++) {
    if (reducedAxes.has(dim)) {
      reduceSize *= shape[dim];
    } else {
      outputSize *= shape[dim];
    }
  }

  // The condition data is very rough, although considering the count of Execution Unit (EU), the potential
  // work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments
  // on some machines.
  return reduceSize < 32 && outputSize > 1024;
};
333

334
/** ReduceMean entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceMeanNaive
    : reduceMeanShared;
  kernel(context, attributes);
};
341

342
/** ReduceL1 entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceL1Naive
    : reduceL1Shared;
  kernel(context, attributes);
};
349

350
/** ReduceL2 entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceL2Naive
    : reduceL2Shared;
  kernel(context, attributes);
};
357

358
/** ReduceLogSumExp entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceLogSumExp = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceLogSumExpNaive
    : reduceLogSumExpShared;
  kernel(context, attributes);
};
365

366
/** ReduceMax entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceMax = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceMaxNaive
    : reduceMaxShared;
  kernel(context, attributes);
};
373

374
/** ReduceMin entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceMin = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceMinNaive
    : reduceMinShared;
  kernel(context, attributes);
};
381

382
/** ReduceProd entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceProd = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceProdNaive
    : reduceProdShared;
  kernel(context, attributes);
};
389

390
/** ReduceSum entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceSumNaive
    : reduceSumShared;
  kernel(context, attributes);
};
397

398
/** ReduceSumSquare entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceSumSquareNaive
    : reduceSumSquareShared;
  kernel(context, attributes);
};
405

406
/** ReduceLogSum entry point: dispatches to the naive or shared-memory kernel according to useNaiveReduceMethod. */
export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttributes): void => {
  const { axes, noopWithEmptyAxes } = attributes;
  const kernel = useNaiveReduceMethod(context.inputs[0].dims, axes, noopWithEmptyAxes)
    ? reduceLogSumNaive
    : reduceLogSumShared;
  kernel(context, attributes);
};
413

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.