4
import { TensorView } from '../../tensor-view';
5
import { ShapeUtil } from '../../util';
6
import { ComputeContext, ProgramInfo } from '../types';
8
import { inputVariable, outputVariable, ShaderHelper } from './common';
10
/**
 * Checks that the BiasAdd inputs have the shapes the generated shader relies on.
 *
 * Expected inputs:
 * - `inputs[0]` (input): rank 3, with a last (channel) dimension of 320, 640 or 1280;
 * - `inputs[1]` (bias): rank 1, with length equal to the channel count of `inputs[0]`;
 * - `inputs[2]` (residual): exactly the same shape as `inputs[0]`.
 *
 * @param inputs tensors passed to the BiasAdd kernel.
 * @throws Error describing the first violated constraint.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  // The program binds input, bias and residual; report a clear error instead of a TypeError
  // from dereferencing a missing tensor below.
  if (inputs.length < 3) {
    throw new Error('BiasAdd requires 3 inputs: input, bias and residual');
  }

  if (inputs[0].dims.length !== 3) {
    throw new Error('input should have 3 dimensions');
  }

  if (![320, 640, 1280].includes(inputs[0].dims[2])) {
    throw new Error('number of channels should be 320, 640 or 1280');
  }

  if (inputs[1].dims.length !== 1) {
    throw new Error('bias is expected to have 1 dimensions');
  }

  if (inputs[0].dims[2] !== inputs[1].dims[0]) {
    throw new Error('last dimension of input and bias are not the same');
  }

  // The shader reads the residual at every output offset, so any shape mismatch would read
  // out of bounds (or silently ignore elements).
  const inputDims = inputs[0].dims;
  const residualDims = inputs[2].dims;
  if (
    residualDims.length !== inputDims.length ||
    residualDims.some((dim: number, i: number) => dim !== inputDims[i])
  ) {
    throw new Error('residual should have the same shape as input');
  }
};
28
/**
 * Builds the WebGPU program computing `output = input + bias + residual`.
 *
 * The bias is broadcast along the last (channel) dimension: flattened vector element `i` is
 * paired with bias vector element `i % (channels / 4)`. All tensors are processed as
 * 4-component vectors.
 *
 * @param inputs `[input, bias, residual]`; expected to have passed `validateInputs`.
 * @returns program info whose single output has the shape and data type of `inputs[0]`.
 */
const createBiasAddProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => {
  const outputShape = inputs[0].dims;
  const channels = inputs[0].dims[2];
  // validateInputs only admits 320/640/1280 channels, all multiples of 4, so the element
  // count splits evenly into 4-component vectors.
  const outputSize = ShapeUtil.size(outputShape) / 4;

  const dataType = inputs[0].dataType;
  const input = inputVariable('input', dataType, outputShape, 4);
  const bias = inputVariable('bias', dataType, [channels], 4);
  const residual = inputVariable('residual', dataType, outputShape, 4);
  const output = outputVariable('output', dataType, outputShape, 4);

  // `channels` inside the shader counts vec4 elements per row, hence the division by 4.
  const getShaderSource = (shaderHelper: ShaderHelper) => `
  const channels = ${channels}u / 4;
  ${shaderHelper.declareVariables(input, bias, residual, output)}

  ${shaderHelper.mainStart()}
    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
    let value = ${input.getByOffset('global_idx')}
      + ${bias.getByOffset('global_idx % channels')} + ${residual.getByOffset('global_idx')};
    ${output.setByOffset('global_idx', 'value')}
  }`;

  return {
    name: 'BiasAdd',
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      // One invocation per vec4 element; 64 is the workgroup size.
      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
    }),
    getShaderSource,
  };
};
62
/**
 * Entry point for the BiasAdd kernel.
 *
 * Checks `context.inputs` with `validateInputs` (which throws an `Error` on unacceptable
 * shapes), then builds the program with `createBiasAddProgramInfo` and submits it through
 * `context.compute`.
 *
 * @param context compute context supplying the kernel inputs.
 */
export const biasAdd = (context: ComputeContext): void => {
// NOTE(review): the bare number lines in this block look like line-number residue from a bad
// paste/extraction and should be removed.
63
validateInputs(context.inputs);
64
context.compute(createBiasAddProgramInfo(context.inputs));