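// [Page residue captured from the GitVerse file viewer; not part of the original source.]
// onnxruntime
// Форк (Fork): 0
// 163 строки · 6.0 Кб (163 lines · 6.0 KB)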
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';
9

10
import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
11

12
/** Attributes for the Concat kernel in this file, as produced by `parseConcatAttributes`. */
export interface ConcatAttributes extends AttributeWithCacheKey {
  /** Concat axis as read from the raw attributes; `concat` normalizes it with `ShapeUtil.normalizeAxis` before use. */
  readonly axis: number;
}
15

16
/**
 * Checks that the Concat inputs can be joined along `axis`.
 *
 * Every input after the first is compared against the first one, in order. The element type must match,
 * the rank must match, and every dimension except `axis` must have the same extent. The first violation
 * found is thrown.
 *
 * @param inputs - the tensors to concatenate.
 * @param axis - concat axis; `concat` passes the value already normalized by `ShapeUtil.normalizeAxis`.
 * @throws Error when there are no inputs, or when types, ranks, or non-axis dimensions disagree.
 */
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }

  const { dataType: expectedType, dims: expectedDims } = inputs[0];

  for (const candidate of inputs.slice(1)) {
    // all inputs must share the first input's element type
    if (candidate.dataType !== expectedType) {
      throw new Error('input tensors should be one type');
    }
    // ... and its rank
    if (candidate.dims.length !== expectedDims.length) {
      throw new Error('input tensors should have the same shape');
    }
    // only the concat axis is allowed to differ in extent
    const matchesOffAxis = candidate.dims.every((extent, d) => d === axis || extent === expectedDims[d]);
    if (!matchesOffAxis) {
      throw new Error('non concat dimensions must match');
    }
  }
};
43

44
/**
 * Emits the WGSL helper `calculateInputIndex(index: u32) -> u32`.
 *
 * The helper returns the first `i` for which `index < sizeInConcatAxis[i]`, or `numberOfTensors` when no
 * entry is larger than `index`. The template must contain only shader code, because everything inside it
 * is emitted verbatim into the WGSL module.
 *
 * @param numberOfTensors - number of entries in the size array (one per concatenated input).
 * @param sizeInConcatAxisStr - comma-separated WGSL expressions for the array entries. `createConcatProgramInfo`
 *   passes the running (inclusive) sums of the inputs' extents on the concat axis, so the helper yields the
 *   index of the input that owns `index`.
 * @returns the WGSL source of the helper function.
 */
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
  fn calculateInputIndex(index: u32) -> u32 {
    let sizeInConcatAxis = array<u32, ${numberOfTensors}u>(${sizeInConcatAxisStr});
    for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u) {
      if (index < sizeInConcatAxis[i]) {
        return i;
      }
    }
    return ${numberOfTensors}u;
  }`;
54

55
/**
 * Emits the WGSL statements that copy the current element from its owning input into the output.
 *
 * Each branch reads `inputs[k]` at the WGSL variable `indices` and writes the value to `output` at offset
 * `global_idx`. The branch is selected by comparing the WGSL variable `inputIndex` against `k`.
 * - With a single input, the copy is emitted unconditionally.
 * - With several inputs, the first is an `if`, the middle ones are `else if`, and the last is the plain `else`.
 *
 * @param inputs - one indices helper per input tensor, in concat order.
 * @param output - indices helper of the output tensor.
 * @returns the WGSL statements joined with newlines; an empty string when `inputs` is empty.
 */
const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelper): string => {
  const lastIndex = inputs.length - 1;
  return inputs
    .map((input, i) => {
      const copy = output.setByOffset('global_idx', input.getByIndices('indices'));
      if (inputs.length === 1) {
        return copy;
      }
      if (i === 0) {
        return `if (inputIndex == ${i}u) { ${copy} }`;
      }
      // inputIndex is a u32 (see calculateInputIndex), so every comparison uses a `u` literal like the first branch.
      return i === lastIndex ? `else { ${copy} }` : `else if (inputIndex == ${i}u) { ${copy} }`;
    })
    .join('\n');
};
73

74
/**
 * Builds the WebGPU program that concatenates `inputs` along `adjustedAxis` into a tensor of `outputShape`.
 *
 * Each shader invocation handles one output element (`global_idx`, bounded by `uniforms.outputSize`):
 * 1. It converts the flat offset into output indices.
 * 2. It finds the input that owns the coordinate on the concat axis via `calculateInputIndex`.
 * 3. It rebases that coordinate by subtracting the extents of all preceding inputs.
 * 4. It copies the element.
 *
 * `sizeInConcatAxis[i]` is the running sum of the inputs' extents on the concat axis up to and including
 * input `i`. It reaches the shader as the uniforms `sizeInConcatAxis0..N-1`.
 *
 * @param inputs - tensors to concatenate, all with the same rank (`concat` passes only non-empty inputs).
 * @param adjustedAxis - concat axis as a non-negative index into each input's dims.
 * @param outputShape - shape of the concatenated output.
 * @param dataType - element type used for every input and for the output.
 * @returns the program description passed to `ComputeContext.compute`.
 */
const createConcatProgramInfo = (
  inputs: readonly TensorView[],
  adjustedAxis: number,
  outputShape: number[],
  dataType: DataType,
): ProgramInfo => {
  const outputSize = ShapeUtil.size(outputShape);
  const inputCount = inputs.length;

  const sizeInConcatAxis = new Array<number>(inputCount);
  const inputVars = new Array<IndicesHelper>(inputCount);
  const inputDependencies: ProgramInputTensorInfoDependency[] = [];

  // Order matters: outputSize and the sizeInConcatAxis values are pushed in the same order they are
  // registered in getShaderSource. The input and output shape entries follow, presumably matching the order
  // in which declareVariables(...inputVars, output) declares them. NOTE(review): confirm in ShaderHelper.
  const programUniforms: ProgramUniform[] = [{ type: DataType.uint32, data: outputSize }];
  let runningSum = 0;
  for (let i = 0; i < inputCount; ++i) {
    runningSum += inputs[i].dims[adjustedAxis];
    sizeInConcatAxis[i] = runningSum;
    inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims.length);
    inputDependencies.push('rank');
    programUniforms.push({ type: DataType.uint32, data: runningSum });
  }
  for (const input of inputs) {
    programUniforms.push(...createTensorShapeVariables(input.dims));
  }
  programUniforms.push(...createTensorShapeVariables(outputShape));

  const output = outputVariable('output', dataType, outputShape.length);
  const indicesAxis = output.indicesGet('indices', adjustedAxis);
  const sizeInConcatAxisStr = Array.from({ length: inputCount }, (_, i) => `uniforms.sizeInConcatAxis${i}`).join(
    ',',
  );

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // Register the scalar uniforms first, then declare the buffers, mirroring the programUniforms order above.
    shaderHelper.registerUniform('outputSize', 'u32');
    for (let i = 0; i < inputCount; i++) {
      shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
    }
    const declarations = shaderHelper.declareVariables(...inputVars, output);

    // Everything inside this template is emitted verbatim as WGSL.
    return `
  ${declarations}

  ${calculateInputIndexImpl(inputCount, sizeInConcatAxisStr)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}

    var indices = ${output.offsetToIndices('global_idx')};

    let inputIndex = calculateInputIndex(${indicesAxis});
    if (inputIndex != 0u) {
      let sizeInConcatAxis = array<u32, ${inputCount}u>(${sizeInConcatAxisStr});
      ${indicesAxis} -= sizeInConcatAxis[inputIndex - 1u];
    }

    ${assignOutputData(inputVars, output)}
  }`;
  };

  return {
    name: 'Concat',
    shaderCache: { hint: `${adjustedAxis}`, inputDependencies },
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
    getShaderSource,
  };
};
144

145
/**
 * Concat kernel entry point: joins all of `context.inputs` along `attributes.axis`.
 *
 * The axis is normalized with `ShapeUtil.normalizeAxis` against the first input's rank, and then the inputs
 * are validated. The output shape equals the first input's shape, except on the concat axis, where it is the
 * sum of every input's extent.
 *
 * Inputs with zero elements are still validated and their axis extent is still included in that sum, but
 * they are not bound to the GPU program.
 *
 * @param context - compute context supplying the inputs and running the program.
 * @param attributes - parsed Concat attributes.
 * @throws Error('too few inputs') when there are no inputs; `validateInputs` throws for mismatched inputs.
 */
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
  const inputs = context.inputs;
  // Report the same error validateInputs uses, instead of a TypeError from reading inputs[0] below.
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }
  const inputShape = inputs[0].dims;
  const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
  validateInputs(inputs, adjustedAxis);

  const outputShape = inputShape.slice();
  outputShape[adjustedAxis] = inputs.reduce(
    (sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0),
    0,
  );

  // Zero-element inputs are valid for Concat but contribute no data, so they are not bound to the program.
  const nonEmptyInputs = inputs.filter((input) => ShapeUtil.size(input.dims) > 0);
  context.compute(createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {
    inputs: nonEmptyInputs,
  });
};
161

162
/**
 * Converts the raw attribute map into `ConcatAttributes`.
 * Only `axis` is read; it is cast to a number without further validation.
 */
export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({ axis });
};
164

/*
 * [Page residue captured from the GitVerse file viewer; not part of the original source.]
 * Использование cookies (cookie usage notice)
 *
 * Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
 *
 * Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
 *
 * Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.
 */