1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { WebGLInferenceHandler } from '../inference-handler';
6
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
8
import { ConvAttributes } from './conv';
10
/**
 * Builds the metadata object literal for the im2col program.
 *
 * The visible part of the literal declares one input texture of type `TextureType.unpacked`.
 *
 * NOTE(review): the remaining fields of the literal (and how `cacheHint` is used in it) are not
 * visible in this excerpt — confirm against the full file.
 *
 * @param cacheHint string supplied by the caller; the loader in this file passes `attributes.cacheKey`.
 */
const createIm2ColProgramMetadata = (cacheHint: string) => ({
13
inputTypes: [TextureType.unpacked],
17
/**
 * Builds the ProgramInfo for the im2col shader.
 *
 * Generates GLSL source specialised with the input dimensions and the convolution attributes
 * (kernel shape, dilations, strides, pads), and declares an output whose dims come from
 * `calculateIm2ColDims`, whose type is `x.type`, and whose texture type is
 * `TextureType.packedLastDimension`.
 *
 * NOTE(review): parts of the signature and of the shader body are not visible in this excerpt.
 * `x` and `w` are presumably the input and weight tensors forwarded by the loader — confirm.
 */
const createIm2ColProgramInfo = (
18
_inferenceHandler: WebGLInferenceHandler,
19
metadata: ProgramMetadata,
22
outputShape: readonly number[],
23
attributes: ConvAttributes,
25
// xshape[1], xshape[2] and xshape[3] are emitted below as the shader constants XC, XH and XW.
const xshape = x.dims;
26
const wshape = w.dims;
28
// rank sizes the `indices` array parameter of the shader's process() function.
const rank = outputShape.length;
29
// The literal 4 is the same value as `outputChannels` in the shader, which loops over that many
// components of a vec4 — presumably the two must stay in sync.
const im2colDims = calculateIm2ColDims(xshape, wshape, outputShape, 4);
31
// Shape and attribute values are interpolated into the GLSL source as `const int` declarations.
const shaderSource = `
32
const int XC = ${xshape[1]};
33
const int XH = ${xshape[2]};
34
const int XW = ${xshape[3]};
35
const int KH = ${attributes.kernelShape[0]};
36
const int KW = ${attributes.kernelShape[1]};
37
const int dilationH = ${attributes.dilations[0]};
38
const int dilationW = ${attributes.dilations[1]};
39
const int strideH = ${attributes.strides[0]};
40
const int strideW = ${attributes.strides[1]};
41
const int padH = ${attributes.pads[0]};
42
const int padW = ${attributes.pads[1]};
43
const int KHKW = KH*KW;
44
const int XCKHKW = XC * KHKW;
45
const int outputChannels = 4;
46
vec4 process(int indices[${rank}]) {
47
int b = indices[0]; // batch size
48
int oh = indices[1] * strideH - padH; //output height
49
int ow = indices[2] * strideW - padW; //output width
50
int p = indices[3] * outputChannels; //patch
51
vec4 value = vec4(0.0);
52
for(int i=0; i < outputChannels; ++i) {
54
int patchC = p / KHKW;
55
int patchH = (p - patchC*KHKW) / KW;
56
int patchW = (p - patchC*KHKW) - patchH * KW;
57
int xh2 = oh + patchH * dilationH;
58
int xw2 = ow + patchW * dilationW;
59
int x[${xshape.length}];
78
output: { dims: im2colDims, type: x.type, textureType: TextureType.packedLastDimension },
83
/**
 * Creates a ProgramInfoLoader for the im2col program.
 *
 * The metadata is built immediately from `attributes.cacheKey`. The ProgramInfo itself is built
 * only when the loader's `get` callback is invoked, because `get` is an arrow function that
 * defers the call to `createIm2ColProgramInfo`.
 *
 * NOTE(review): the declarations of `x` and `w` are not visible in this excerpt; they are
 * forwarded unchanged to `createIm2ColProgramInfo`.
 */
export const createIm2ColProgramInfoLoader = (
84
inferenceHandler: WebGLInferenceHandler,
87
outputShape: readonly number[],
88
attributes: ConvAttributes,
89
): ProgramInfoLoader => {
90
// The cache key of the convolution attributes is used as the program's cache hint.
const metadata = createIm2ColProgramMetadata(attributes.cacheKey);
93
get: () => createIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes),
97
/**
 * Computes the dimensions of the im2col output.
 *
 * The one visible entry is ceil(inputShape[1] * kernelShape[2] * kernelShape[3] / channels).
 *
 * NOTE(review): the declaration of `channels` and the other entries of the result are not visible
 * in this excerpt. `channels` is presumably the fourth parameter, since the caller in this file
 * passes 4 in that position — confirm.
 */
export const calculateIm2ColDims = (
98
inputShape: readonly number[],
99
kernelShape: readonly number[],
100
outputShape: readonly number[],
106
// Math.ceil rounds up, so a partially filled final group of `channels` values still counts.
Math.ceil((inputShape[1] * kernelShape[2] * kernelShape[3]) / channels),