1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { WebGLInferenceHandler } from '../inference-handler';
7
import { calculateOutputShape, ConvAttributes } from './conv';
8
import { createPackedIm2ColProgramInfoLoader } from './im2col-pack';
9
import { createPackedMatmulProgramInfoLoader } from './matmul-pack';
11
/**
 * Computes a 2D convolution by flattening the input and the kernel to 2D and running one packed matmul.
 *
 * NOTE(review): the reshapes below ignore xshape[0] and kshape[2..3], so this presumably expects a batch of 1
 * and a 1x1 kernel (as "pointwise" suggests) — confirm against the caller.
 *
 * @param inferenceHandler - handler used to reshape the packed tensors and to run the matmul program.
 * @param inputs - inputs[0] is the input tensor and inputs[1] the kernel; when more than two inputs are given,
 *   inputs[2] is forwarded to the matmul as a third input (presumably the bias — confirm).
 * @param attributes - conv attributes; dilations, pads and strides feed the output-shape calculation, and the
 *   whole object is handed to the matmul program loader.
 * @returns the matmul result reshaped to the computed convolution output shape.
 */
export const conv2DPackedPointwise = (
12
inferenceHandler: WebGLInferenceHandler,
13
inputs: readonly Tensor[],
14
attributes: ConvAttributes,
16
const xshape = inputs[0].dims;
17
const kshape = inputs[1].dims;
18
// Final shape of the convolution result; the matmul output is reshaped to it at the end.
const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
19
// Flatten the input to 2D: [xshape[1], xshape[2] * xshape[3]].
const reshapedX = inferenceHandler.reshapePacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]);
20
// Flatten the kernel to 2D: [kshape[0], kshape[1]].
const reshapedK = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1]]);
22
// The kernel goes first, then the input; the optional third input is appended only when it was supplied.
const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX];
23
const matmulOutput = inferenceHandler.run(
24
createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes),
27
// Restore the convolution output shape computed above.
return inferenceHandler.reshapePacked(matmulOutput, outputShape);
30
/**
 * Computes a 2D convolution in three steps: a packed im2col program over the input, a reshape of the kernel
 * to 2D, and a packed matmul of the two; the matmul result is then reshaped to the convolution output shape.
 *
 * @param inferenceHandler - handler used to run the im2col and matmul programs and to reshape packed tensors.
 * @param inputs - inputs[0] is the input tensor and inputs[1] the kernel; when exactly three inputs are given,
 *   inputs[2] is forwarded to the matmul as a third input (presumably the bias — confirm).
 * @param attributes - conv attributes; dilations, pads and strides feed the output-shape calculation, and the
 *   whole object is handed to both the im2col and the matmul program loaders.
 * @returns the matmul result reshaped to the computed convolution output shape.
 */
export const conv2DPacked = (
31
inferenceHandler: WebGLInferenceHandler,
32
inputs: readonly Tensor[],
33
attributes: ConvAttributes,
35
const xshape = inputs[0].dims;
36
const kshape = inputs[1].dims;
37
// Final shape of the convolution result; also passed to the im2col program loader below.
const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
40
// Step 1: run the packed im2col program built from the input, kernel, output shape and attributes.
const im2colOutput = inferenceHandler.run(
41
createPackedIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes),
46
// Step 2: flatten the kernel to 2D: [kshape[0], kshape[1] * kshape[2] * kshape[3]].
const kernelReshaped = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1] * kshape[2] * kshape[3]]);
49
// Step 3: matmul with the kernel first, then the im2col result, plus the optional third input.
// NOTE(review): this checks `inputs.length === 3`, whereas conv2DPackedPointwise checks `inputs.length > 2` —
// confirm whether the difference is intentional.
const matmulInputs = inputs.length === 3 ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput];
50
const matmulOutput = inferenceHandler.run(
51
createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes),
56
// Restore the convolution output shape computed above.
const outputReshaped = inferenceHandler.reshapePacked(matmulOutput, outputShape);
57
return outputReshaped;