// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
8
import { getCoordsDataType } from '../utils';
10
import { getChannels, unpackFromChannel } from './packing-utils';
12
const unpackProgramMetadata = {
15
inputTypes: [TextureType.packed],
18
export const createUnpackProgramInfo = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfo => {
19
const rank = input.dims.length;
21
const channels = getChannels('rc', rank);
22
const innerDims = channels.slice(-2);
23
const coordsDataType = getCoordsDataType(rank);
24
const unpackChannel = unpackFromChannel();
25
const isScalar = input.dims.length === 0;
26
const sourceCoords = isScalar ? '' : getSourceCoords(rank, channels);
27
const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
28
const glsl = getGlsl(handler.session.backend.glContext.version);
29
const shaderSource = `
32
${coordsDataType} rc = getOutputCoords();
34
// Sample the texture with the coords to get the rgba channel value.
35
vec4 packedInput = getA(${sourceCoords});
37
${glsl.output} = vec4(getChannel(packedInput, ${coords}), 0, 0, 0);
42
...unpackProgramMetadata,
44
output: { dims: input.dims, type: input.type, textureType: TextureType.unpacked },
49
export const createUnpackProgramInfoLoader = (handler: WebGLInferenceHandler, input: Tensor): ProgramInfoLoader => ({
50
...unpackProgramMetadata,
51
get: () => createUnpackProgramInfo(handler, input),
54
function getSourceCoords(rank: number, dims: string[]): string {
60
for (let i = 0; i < rank; i++) {