onnxruntime

Форк
0
79 строк · 3.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo } from '../types';
9

10
import { createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper } from './common';
11

12
export interface CumSumAttributes extends AttributeWithCacheKey {
13
  readonly exclusive: boolean;
14
  readonly reverse: boolean;
15
}
16
const createCumsumProgramInfo = (
17
  inputType: number,
18
  inputShape: readonly number[],
19
  axisInput: TensorView,
20
  attributes: CumSumAttributes,
21
): ProgramInfo => {
22
  const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape.
23
  const rank = inputShape.length; // input/output rank
24
  const input = inputVariable('input', inputType, rank);
25
  const output = outputVariable('output', inputType, rank);
26
  const axisValue =
27
    axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] : Number(axisInput.getBigInt64Array()[0]);
28
  const axis = ShapeUtil.normalizeAxis(axisValue, rank);
29
  const getShaderSource = (shaderHelper: ShaderHelper) => {
30
    const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `;
31
    const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank);
32
    const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0';
33
    const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1');
34
    return `
35
                ${shaderHelper
36
                  .registerUniform('outputSize', 'u32')
37
                  .registerUniform('axis', 'u32')
38
                  .declareVariables(input, output)}
39
                ${shaderHelper.mainStart()}
40
                  ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
41
                  var inputIndices = ${output.offsetToIndices('global_idx')};
42
                  var sum = ${output.type.value}(0);
43
                  let first : i32 = ${lowerLimit};
44
                  let last : i32 = ${upperLimit};
45
                  for (var i : i32 = first; i < last; i++) {
46
                    ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(i)')};
47
                    sum = sum + ${input.getByIndices('inputIndices')};
48
                  }
49
                  ${output.setByOffset('global_idx', 'sum')};
50
                }`;
51
  };
52
  return {
53
    name: 'CumSum',
54
    shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
55
    getRunData: () => ({
56
      outputs: [{ dims: inputShape, dataType: inputType }],
57
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
58
      programUniforms: [
59
        { type: DataType.uint32, data: outputSize },
60
        { type: DataType.uint32, data: axis },
61
        ...createTensorShapeVariables(inputShape, inputShape),
62
      ],
63
    }),
64
    getShaderSource,
65
  };
66
};
67

68
export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => {
69
  const inputShape = context.inputs[0].dims;
70
  const inputType = context.inputs[0].dataType;
71
  const axis = context.inputs[1];
72
  context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), { inputs: [0] });
73
};
74

75
export const parseCumSumAttributes = (attributes: Record<string, unknown>): CumSumAttributes => {
76
  const exclusive = (attributes.exclusive as number) === 1;
77
  const reverse = (attributes.reverse as number) === 1;
78
  return createAttributeWithCacheKey({ exclusive, reverse });
79
};
80

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.