4
import { Tensor } from '../../../tensor';
5
import { ShapeUtil } from '../../../util';
6
import { getGlsl } from '../glsl-source';
7
import { WebGLInferenceHandler } from '../inference-handler';
8
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
10
import { unpackFromChannel } from './packing-utils';
12
/**
 * Builds the metadata object for the packed reshape program.
 *
 * The program is named 'Reshape (packed)' and takes one packed-texture input.
 * The cache hint is the stringified output shape.
 *
 * @param outputShape3D Target shape of the reshape; only used for the cache hint here.
 */
const createPackedReshape3DProgramMetadata = (outputShape3D: readonly number[]) => ({
13
name: 'Reshape (packed)',
14
inputTypes: [TextureType.packed],
16
cacheHint: `${outputShape3D}`,
19
/**
 * Builds the program info (shader source plus output description) for the
 * packed 3D reshape.
 *
 * Both the input tensor's dims and `outputShape3D` are treated as 3-element
 * shapes. The generated shader fills a vec4 `result`, one component per loop
 * iteration: each component's output coordinate is flattened to a linear
 * index, and that index is mapped back to a coordinate in the input, from
 * which the value is fetched.
 *
 * The output uses `outputShape3D` as dims, the input tensor's element type,
 * and a packed texture.
 *
 * @param handler Inference handler; used to reach the GL context version.
 * @param metadata Program metadata associated with this program.
 * @param outputShape3D Target shape, assumed to have exactly three entries.
 */
const createPackedReshape3DProgramInfo = (
20
handler: WebGLInferenceHandler,
22
metadata: ProgramMetadata,
23
outputShape3D: readonly number[],
25
const inputShape3D = input3D.dims as [number, number, number];
26
const squeezedOutputShape = outputShape3D as [number, number, number];
29
// One pass per component of the vec4 `result` (i = 0..3).
for (let i = 0; i < 4; i++) {
30
let outputCoords = '';
33
outputCoords = 'outputCoords = rc;';
36
outputCoords = 'outputCoords = ivec3(rc.x, rc.y+1, rc.z);';
39
outputCoords = 'outputCoords = ivec3(rc.x, rc.y, rc.z+1);';
42
outputCoords = 'outputCoords = ivec3(rc.x, rc.y+1, rc.z+1);';
50
${i > 0 ? 'if(outputCoords.y < rows && outputCoords.z < cols){' : ''}
51
int flattenedIndex = getFlattenedIndex(outputCoords);
53
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flattenedIndex);
54
vec2 innerDims = vec2(float(inputRC.y),float(inputRC.z));
56
result[${i}] = getChannel(getA(inputRC.x, inputRC.y, inputRC.z), innerDims);
61
// GLSL helper set selected by the GL context version; supplies the name of
// the shader output variable (`glsl.output`) used in the shader source.
const glsl = getGlsl(handler.session.backend.glContext.version);
63
const shaderSource = `
64
${getReshapedInputCoords(inputShape3D)}
65
${getFlattenedIndexFrom3D(squeezedOutputShape)}
66
${unpackFromChannel()}
69
ivec3 rc = getOutputCoords();
71
vec4 result = vec4(0.0);
74
int rows = ${squeezedOutputShape[2]};
75
int cols = ${squeezedOutputShape[1]};
78
${glsl.output} = result;
84
output: { dims: squeezedOutputShape, type: input3D.type, textureType: TextureType.packed },
90
/**
 * Creates the program info loader for the packed 3D reshape.
 *
 * The returned loader carries the program metadata fields directly and builds
 * the full program info lazily, only when its `get` callback is invoked.
 *
 * @param handler Inference handler forwarded to the program info builder.
 * @param outputShape3D Target shape of the reshape.
 * @returns A loader combining the metadata with a deferred `get`.
 */
export const createPackedReshape3DProgramInfoLoader = (
91
handler: WebGLInferenceHandler,
93
outputShape3D: readonly number[],
94
): ProgramInfoLoader => {
95
const metadata = createPackedReshape3DProgramMetadata(outputShape3D);
96
return { ...metadata, get: () => createPackedReshape3DProgramInfo(handler, input3D, metadata, outputShape3D) };
99
/**
 * Reduces a shape of any rank to a 3-element `[batch, rows, cols]` tuple.
 *
 * The last dimension is kept as the third entry. The second-to-last dimension
 * is kept as the second entry, or 1 when the shape has only one dimension.
 * An empty shape is handled separately before the general path.
 * NOTE(review): `batch` is presumably accumulated from the leading
 * `shape.length - 2` dimensions that the loop visits; confirm against the
 * loop body.
 *
 * @param shape Shape to reduce.
 * @returns The 3-element shape.
 */
export function processDims3D(shape: ArrayLike<number>): [number, number, number] {
100
if (shape.length === 0) {
105
for (let i = 0; i < shape.length - 2; ++i) {
108
return [batch, shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]];
122
/**
 * Decides whether reshaping `dims` into `reshapedDims` counts as "cheap".
 *
 * - If either shape is empty (rank 0), the reshape is cheap.
 * - If either shape has rank 1, it is cheap when the last dimensions match.
 * - Otherwise the last two dimensions of both shapes are compared for equality.
 *
 * @param dims Shape before the reshape.
 * @param reshapedDims Shape after the reshape.
 * @returns `true` when the reshape is considered cheap.
 */
export function isReshapeCheap(dims: readonly number[], reshapedDims: readonly number[]) {
123
let isCheapReshape = false;
124
// Rank-0 on either side: always cheap.
if (dims.length === 0 || reshapedDims.length === 0) {
126
isCheapReshape = true;
127
// Rank-1 on either side: only the innermost dimension can be compared.
} else if (dims.length < 2 || reshapedDims.length < 2) {
129
isCheapReshape = dims[dims.length - 1] === reshapedDims[reshapedDims.length - 1];
133
// Rank >= 2 on both sides: innermost and second-innermost dimensions.
dims[dims.length - 1] === reshapedDims[reshapedDims.length - 1] &&
134
dims[dims.length - 2] === reshapedDims[reshapedDims.length - 2];
137
return isCheapReshape;
140
/**
 * Generates the GLSL source of `inputCoordsFromReshapedOutCoords(int index)`,
 * which turns a flat index into an `ivec3(b, r, c)` coordinate.
 *
 * For each stride computed from `shape`, the emitted code divides the index
 * by that stride to get one coordinate.
 *
 * @param shape 3-element shape the flat index is unravelled into.
 * @returns GLSL source text for the helper function.
 */
function getReshapedInputCoords(shape: [number, number, number]): string {
141
const strides = ShapeUtil.computeStrides(shape);
142
const coords = ['b', 'r', 'c'];
143
const index = 'index';
144
const coordsFromIndexSnippet = strides
145
.map((stride, i) => {
146
const line1 = `int ${coords[i]} = ${index} / ${stride}`;
148
// On the last stride the following coordinate is declared as the remainder;
// on earlier strides the current coordinate's contribution is subtracted
// from the running index instead.
// NOTE(review): if computeStrides returns one stride per dimension (three
// here), coords[i + 1] is out of range on the last iteration. Confirm the
// length of `strides`.
i === strides.length - 1
149
? `int ${coords[i + 1]} = ${index} - ${coords[i]} * ${stride}`
150
: `index -= ${coords[i]} * ${stride}`;
151
return `${line1}; ${line2};`;
156
ivec3 inputCoordsFromReshapedOutCoords(int index) {
157
${coordsFromIndexSnippet}
158
return ivec3(b, r, c);
163
/**
 * Generates the GLSL source of `getFlattenedIndex(ivec3 coords)`, which
 * converts a 3D coordinate into a flat index using the first two strides
 * computed from `shape`.
 *
 * The emitted expression multiplies `coords.z` by the second stride and adds
 * `coords.y` unscaled, i.e. the y and z components are used in swapped order.
 *
 * @param shape 3-element shape whose strides drive the flattening.
 * @returns GLSL source text for the helper function.
 */
function getFlattenedIndexFrom3D(shape: [number, number, number]): string {
164
const strides = ShapeUtil.computeStrides(shape);
167
int getFlattenedIndex(ivec3 coords) {
168
// reverse y, z order
169
return coords.x * ${strides[0]} + coords.z * ${strides[1]} + coords.y;