onnxruntime

Форк
0
146 строк · 5.6 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform, TensorInfo } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  getElementAt,
13
  IndicesHelper,
14
  inputVariable,
15
  outputVariable,
16
  ShaderHelper,
17
} from './common';
18

19
/**
 * Attributes consumed by the Split kernel.
 */
export interface SplitAttributes extends AttributeWithCacheKey {
  /** Axis to split along; normalized against the input rank via `ShapeUtil.normalizeAxis` before use. */
  readonly axis: number;
  /** Number of output tensors to produce. */
  readonly numOutputs: number;
  /** Extent of each output along `axis`, in output order; may be empty when no explicit sizes were given. */
  readonly splitSizes: number[];
}
24

25
/**
 * Checks that Split received at least its data input.
 *
 * @throws Error('too few inputs') when `inputs` is missing or empty.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const inputCount = inputs ? inputs.length : 0;
  if (inputCount >= 1) {
    return;
  }
  throw new Error('too few inputs');
};
30

31
/**
 * Builds `SplitAttributes` for the two-input form of Split, where `inputs[1]` carries the split sizes.
 *
 * If `inputs[1]` has a positive leading dimension, its int64 values become `splitSizes` and `numOutputs`
 * is set to their count. Otherwise `splitSizes` is empty and `numOutputs` is taken from `attributes`.
 * `axis` always comes from `attributes`.
 *
 * @param inputs kernel inputs; `inputs[1]` must be present.
 * @param attributes the previously parsed attributes.
 * @returns fresh attributes with a cache key.
 */
const createSplitAttributesFromInputs = (
  inputs: readonly TensorView[],
  attributes: SplitAttributes,
): SplitAttributes => {
  const splitInput = inputs[1];
  const hasExplicitSizes = splitInput.dims[0] > 0;
  // The sizes arrive as bigint; convert them to plain numbers for shape arithmetic.
  const splitSizes = hasExplicitSizes ? Array.from(splitInput.getBigInt64Array(), (v) => Number(v)) : [];
  const numOutputs = hasExplicitSizes ? splitSizes.length : attributes.numOutputs;
  // Field order (numOutputs, axis, splitSizes) is kept as before in case the cache key depends on it.
  return createAttributeWithCacheKey({ numOutputs, axis: attributes.axis, splitSizes });
};
43

44
/**
 * Emits the WGSL helper `calculateOutputIndex`, which maps a coordinate along the split axis to the
 * number of the output that owns it.
 *
 * `uniforms.size_in_split_axis` holds running totals of the split sizes, so the first entry greater than
 * `index` identifies the owning output. If no entry is greater, `numberOfTensors` is returned.
 *
 * @param numberOfTensors number of outputs, i.e. the length of `uniforms.size_in_split_axis`.
 * @returns WGSL source for the helper function.
 */
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
fn calculateOutputIndex(index: u32) -> u32 {
    for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u) {
      if (index < ${getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors)}) {
        return i;
      }
    }
    return ${numberOfTensors}u;
}`;
53
/**
 * Emits the WGSL helper `writeBufferData`, which stores `input[global_idx]` into the output selected by
 * `output_number` at position `indices`.
 *
 * With a single output the store is unconditional. Otherwise an `if` / `else if` / `else` chain dispatches
 * on `output_number`, and the last output takes the final `else` branch.
 *
 * @param outputs helpers for every output, in output order. `outputs[0].type.indices` supplies the WGSL
 *   type of the `indices` parameter, so the list is expected to be non-empty.
 * @returns WGSL source for the helper function.
 */
const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
  const numberOfTensors = outputs.length;
  const codeLines = outputs.map((output, i) => {
    const returnSnippet = output.setByIndices('indices', 'input[global_idx]');
    if (numberOfTensors === 1) {
      return returnSnippet;
    }
    if (i === 0) {
      return `if (output_number == ${i}u) { ${returnSnippet} }`;
    }
    if (i === numberOfTensors - 1) {
      return `else { ${returnSnippet} }`;
    }
    // Spell the literal as u32 here as well, matching the first branch's comparison.
    return `else if (output_number == ${i}u) { ${returnSnippet} }`;
  });
  return `
      fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) {
        ${codeLines.join('\n')}
      }`;
};
73

74
/**
 * Resolves the extent of every output along the split axis.
 *
 * Non-empty `splitSizes` are used as given. When the list is empty and `numOutputs` is positive, the axis
 * is divided into `numOutputs` chunks of `ceil(dimAlongAxis / numOutputs)` elements, and the last chunk
 * receives the remainder.
 *
 * @throws Error if the resolved sizes do not number `numOutputs`, contain a negative entry, or do not add
 *   up to `dimAlongAxis`. Otherwise the running totals uploaded as `size_in_split_axis` would not
 *   partition the axis.
 */
const resolveSplitSizes = (splitSizes: readonly number[], numOutputs: number, dimAlongAxis: number): number[] => {
  let sizes: number[] = [];
  if (splitSizes.length > 0) {
    sizes = splitSizes.slice();
  } else if (numOutputs > 0) {
    const chunk = Math.ceil(dimAlongAxis / numOutputs);
    sizes = new Array<number>(numOutputs).fill(chunk);
    sizes[numOutputs - 1] = dimAlongAxis - chunk * (numOutputs - 1);
  }
  const total = sizes.reduce((sum, size) => sum + size, 0);
  if (sizes.length !== numOutputs || sizes.some((size) => size < 0) || total !== dimAlongAxis) {
    throw new Error(
      `Split: invalid split sizes [${sizes.join(', ')}] for ${numOutputs} output(s) along an axis of size ${dimAlongAxis}`,
    );
  }
  return sizes;
};

/**
 * Builds the GPU program (WGSL) for Split.
 *
 * Each invocation handles one input element. It finds which output owns the element's coordinate along
 * `axis`, rebases that coordinate into the output, and stores the element there.
 *
 * @param inputs kernel inputs; only `inputs[0]` (the data tensor) is read here.
 * @param attributes split configuration. An empty `splitSizes` with a positive `numOutputs` requests an
 *   even split.
 * @throws Error when the split sizes are inconsistent with the input shape (see `resolveSplitSizes`).
 */
const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: SplitAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const inputSize = ShapeUtil.size(inputShape);
  const dataType = inputs[0].dataType;
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
  const splitSizes = resolveSplitSizes(attributes.splitSizes, attributes.numOutputs, inputShape[axis]);
  const numOutputs = splitSizes.length;
  const outputs = new Array<IndicesHelper>(numOutputs);
  const input = inputVariable('input', dataType, inputShape.length);
  // Running totals of splitSizes: entry i is the exclusive end (along `axis`) of output i's slice.
  const sizeInSplitAxis = new Array<number>(numOutputs);
  const outputsTensorInfo: TensorInfo[] = [];
  const outputShapes: number[][] = [];
  let previousSum = 0;
  const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: inputSize }];
  for (let i = 0; i < numOutputs; i++) {
    previousSum += splitSizes[i];
    sizeInSplitAxis[i] = previousSum;
    const outputShape = inputShape.slice();
    outputShape[axis] = splitSizes[i];
    outputShapes.push(outputShape);
    outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
    outputsTensorInfo.push({ dims: outputShape, dataType });
  }
  programUniforms.push(
    { type: DataType.uint32, data: sizeInSplitAxis },
    ...createTensorShapeVariables(inputShape, ...outputShapes),
  );
  // Shader outline: locate the owning output from the running totals, subtract the preceding total so the
  // coordinate is relative to that output, then dispatch the store through writeBufferData.
  const getShaderSource = (shaderHelper: ShaderHelper) => `
  ${shaderHelper
    .registerUniform('input_size', 'u32')
    .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length)
    .declareVariables(input, ...outputs)}
  ${calculateOutputIndexImpl(sizeInSplitAxis.length)}
  ${writeBufferDataImpl(outputs)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')}

    var indices = ${input.offsetToIndices('global_idx')};
    var index = ${input.indicesGet('indices', axis)};
    let output_number = calculateOutputIndex(index);
    if (output_number != 0) {
      index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)};
      ${input.indicesSet('indices', axis, 'index')};
    }
    writeBufferData(output_number, indices, global_idx);
  }`;
  return {
    name: 'Split',
    shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank'] },
    getShaderSource,
    getRunData: () => ({
      outputs: outputsTensorInfo,
      dispatchGroup: { x: Math.ceil(inputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
  };
};
130

131
/**
 * Entry point for the Split kernel.
 *
 * When a second input is present, the attributes are rebuilt from it via `createSplitAttributesFromInputs`;
 * otherwise the given attributes are used unchanged. Only input 0 is bound to the GPU program
 * (`inputs: [0]`), because any split sizes have already been read on the JS side.
 */
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
  validateInputs(context.inputs);
  const updatedAttributes =
    context.inputs.length === 1 ? attributes : createSplitAttributesFromInputs(context.inputs, attributes);
  context.compute(createSplitProgramInfo(context.inputs, updatedAttributes), { inputs: [0] });
};
137

138
/**
 * Converts the raw attribute record into `SplitAttributes`.
 *
 * A negative `numOutputs` is replaced by the number of split sizes. An absent `splitSizes` entry is
 * treated as an empty list.
 *
 * @param attributes raw record with `axis`, `numOutputs` and (optionally) `splitSizes`.
 * @returns attributes with a cache key.
 * @throws Error if `numOutputs` and the number of split sizes disagree.
 */
export const parseSplitAttributes = (attributes: Record<string, unknown>): SplitAttributes => {
  const axis = attributes.axis as number;
  // Tolerate an absent entry instead of failing with a TypeError on `.length`.
  const splitSizes = (attributes.splitSizes as number[] | undefined) ?? [];
  const requestedOutputs = attributes.numOutputs as number;
  const numOutputs = requestedOutputs < 0 ? splitSizes.length : requestedOutputs;
  if (numOutputs !== splitSizes.length) {
    throw new Error('numOutputs and splitSizes length must be equal');
  }
  return createAttributeWithCacheKey({ axis, numOutputs, splitSizes });
};
147

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.