1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
8
import { getCoordsDataType } from '../utils';
10
import { getChannels } from './packing-utils';
12
const packProgramMetadata = {
15
inputTypes: [TextureType.unpackedReversed],
18
const createPackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
19
const glsl = getGlsl(handler.session.backend.glContext.version);
20
const inputShape = input.dims;
22
const inputRank = inputShape.length;
23
// createTextureLayoutFromShape won't change output rank. Need to verify by running tests
24
const outputRank = input.dims.length;
26
const coordsDataType = getCoordsDataType(outputRank);
27
const channels = getChannels('rc', outputRank);
28
const setup = getSetup(outputRank, channels, inputShape[inputShape.length - 2], inputShape[inputShape.length - 1]);
31
if (inputRank === 0) {
32
reversedInputWH = [1, 1];
33
} else if (inputRank === 1) {
34
reversedInputWH = [inputShape[0], 1];
36
reversedInputWH = [inputShape[outputRank - 1], inputShape[outputRank - 2]];
38
const outOfBoundsCondition = getOutOfBoundsCondition(outputRank, reversedInputWH, channels);
39
const output = getOutput(inputShape, channels);
41
const shaderSource = `
43
${coordsDataType} rc = getOutputCoords();
45
if(${outOfBoundsCondition}) {
46
${glsl.output} = vec4(0);
50
${glsl.output} = vec4(${output});
55
...packProgramMetadata,
57
output: { dims: input.dims, type: input.type, textureType: TextureType.packed },
62
export const createPackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({
63
...packProgramMetadata,
64
get: () => createPackProgramInfo(handler, input),
68
* build the GLSL boolean condition that is true when the output coordinate lies outside the input's width/height boundary
70
function getOutOfBoundsCondition(rank: number, shape: readonly number[], dims: string[]): string {
75
return `rc > ${shape[0]}`;
79
for (let i = rank - 2; i < rank; i++) {
80
cond += `${dims[i]} >= ${shape[i - rank + 2]}`;
90
* code snippet to sample input texture with output coordinates
92
function getOutput(shape: readonly number[], dims: string[]): string {
93
const rank = shape.length;
96
return 'getA(), 0, 0, 0';
101
rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),
105
const coord00 = 'r, c';
106
const coord01 = 'r, cp1';
107
const coord10 = 'rp1, c';
108
const coord11 = 'rp1, cp1';
111
for (let i = 0; i < rank - 2; ++i) {
112
D = D + `${dims[i]},`;
115
return `getA(${D}${coord00}),
116
rEdge ? 0. : getA(${D}${coord10}),
117
cEdge ? 0. : getA(${D}${coord01}),
118
rEdge || cEdge ? 0. : getA(${D}${coord11})`;
122
* code snippet to set up the 4 coordinates and the edge conditions
124
function getSetup(rank: number, dims: string[], rows: number, cols: number): string {
125
if (rank === 0 || rank === 1) {
128
// rank >= 2 for width+height pack.
131
int r = ${dims[rank - 2]};
132
int c = ${dims[rank - 1]};
133
int rp1 = ${dims[rank - 2]} + 1;
134
int cp1 = ${dims[rank - 1]} + 1;
135
bool rEdge = rp1 >= ${cols};
136
bool cEdge = cp1 >= ${rows};