1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
9
import { ConvAttributes } from './conv';
10
import { unpackFromChannel } from './packing-utils';
12
const createPackedIm2ColProgramMetadata = (cacheHint: string) => ({
13
name: 'Im2Col (packed)',
15
inputTypes: [TextureType.packed],
19
const createPackedIm2ColProgramInfo = (
20
inferenceHandler: WebGLInferenceHandler,
21
metadata: ProgramMetadata,
24
outputShape: readonly number[],
25
attributes: ConvAttributes,
27
const xshape = x.dims;
28
const wshape = w.dims;
31
const rank = outputShape.length;
32
const im2colShape = [wshape[1] * wshape[2] * wshape[3], outputShape[2] * outputShape[3]];
33
const kernelSize = wshape[2] * wshape[3];
34
const unpackChannel = unpackFromChannel();
35
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
38
for (let row = 0; row <= 1; row++) {
39
for (let col = 0; col <= 1; col++) {
41
blockIndex = rc.x + ${col};
44
if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) {
45
offsetY = int(blockIndex / (${outputShape[rank - 1]})) * ${attributes.strides[0]} -
46
${attributes.pads[0]};
47
d0 = offsetY + ${attributes.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]});
49
if(d0 < ${xshape[rowDim]} && d0 >= 0) {
50
offsetX = imod(blockIndex, ${outputShape[rank - 1]}) * ${attributes.strides[1]} -
51
${attributes.pads[1]};
52
d1 = offsetX + ${attributes.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]});
54
if(d1 < ${xshape[colDim]} && d1 >= 0) {
56
ch = int(float(pos)/ ${kernelSize}.);
57
innerDims = vec2(d0, d1);
58
result[${row * 2 + col}] = getChannel(
59
getA(0, ch, int(innerDims.x),
60
int(innerDims.y)), innerDims);
69
const shaderSource = `
73
ivec2 rc = getOutputCoords();
74
vec4 result = vec4(0.0);
75
int blockIndex, pos, offsetY, d0, offsetX, d1, ch;
78
${glsl.output} = result;
83
output: { dims: im2colShape, type: x.type, textureType: TextureType.packed },
89
export const createPackedIm2ColProgramInfoLoader = (
90
inferenceHandler: WebGLInferenceHandler,
93
outputShape: readonly number[],
94
attributes: ConvAttributes,
95
): ProgramInfoLoader => {
96
const metadata = createPackedIm2ColProgramMetadata(attributes.cacheKey);
99
get: () => createPackedIm2ColProgramInfo(inferenceHandler, metadata, x, w, outputShape, attributes),