onnxruntime

Форк
0
65 строк · 2.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { TensorView } from '../../tensor-view';
5
import { ShapeUtil } from '../../util';
6
import { ComputeContext, ProgramInfo } from '../types';
7

8
import { inputVariable, outputVariable, ShaderHelper } from './common';
9

10
const validateInputs = (inputs: readonly TensorView[]): void => {
11
  if (inputs[0].dims.length !== 3) {
12
    throw new Error('input should have 3 dimensions');
13
  }
14

15
  if (![320, 640, 1280].includes(inputs[0].dims[2])) {
16
    throw new Error('number of channels should be 320, 640 or 1280');
17
  }
18

19
  if (inputs[1].dims.length !== 1) {
20
    throw new Error('bias is expected to have 1 dimensions');
21
  }
22

23
  if (inputs[0].dims[2] !== inputs[1].dims[0]) {
24
    throw new Error('last dimension of input and bias are not the same');
25
  }
26
};
27

28
const createBiasAddProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
29
  const outputShape = inputs[0].dims;
30

31
  const channels = inputs[0].dims[2];
32
  // since channel number can be only 320/640/1280, it's always divisable by 4
33
  const outputSize = ShapeUtil.size(outputShape) / 4;
34

35
  const dataType = inputs[0].dataType;
36
  const input = inputVariable('input', dataType, outputShape, 4);
37
  const bias = inputVariable('bias', dataType, [channels], 4);
38
  const residual = inputVariable('residual', dataType, outputShape, 4);
39
  const output = outputVariable('output', dataType, outputShape, 4);
40

41
  const getShaderSource = (shaderHelper: ShaderHelper) => `
42
  const channels = ${channels}u / 4;
43
  ${shaderHelper.declareVariables(input, bias, residual, output)}
44

45
  ${shaderHelper.mainStart()}
46
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
47
    let value = ${input.getByOffset('global_idx')}
48
      + ${bias.getByOffset('global_idx % channels')} + ${residual.getByOffset('global_idx')};
49
    ${output.setByOffset('global_idx', 'value')}
50
  }`;
51

52
  return {
53
    name: 'BiasAdd',
54
    getRunData: () => ({
55
      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
56
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
57
    }),
58
    getShaderSource,
59
  };
60
};
61

62
export const biasAdd = (context: ComputeContext): void => {
63
  validateInputs(context.inputs);
64
  context.compute(createBiasAddProgramInfo(context.inputs));
65
};
66

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.