4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
8
import { getCoordsDataType, getGlChannels } from '../utils';
10
import { ConcatAttributes } from './concat';
11
import { getChannels, unpackFromChannel } from './packing-utils';
13
/**
 * Builds the program metadata for the packed Concat WebGL program.
 *
 * Declares one input per tensor being concatenated, named X0, X1, ..., X{inputCount-1},
 * each expected as a packed texture.
 * NOTE(review): cacheHint is not referenced in the visible lines; presumably it is stored
 * as the program's cache hint in a line not shown here — confirm.
 */
const createPackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => ({
14
name: 'Concat (packed)',
15
// One named input per concatenated tensor: X0 .. X{inputCount-1}.
inputNames: Array.from({ length: inputCount }, (_v, i) => `X${i}`),
16
// Every input is consumed in packed texture layout.
inputTypes: Array(inputCount).fill(TextureType.packed),
20
/**
 * Builds the packed-texture WebGL program that concatenates `inputs` along `axis`.
 *
 * Validates `axis` against the rank of the first input (negative values count from the
 * end) and checks that every non-concatenated dimension agrees with the first input.
 * The output shape is the first input's shape with the `axis` dimension replaced by the
 * sum of that dimension over all inputs; the output texture is packed and uses the
 * first input's tensor type.
 *
 * @throws Error if `axis` is outside [-rank, rank) for the first input, or if any
 *   non-concatenated dimension differs between inputs.
 */
const createPackedConcatProgramInfo = (
21
handler: WebGLInferenceHandler,
22
metadata: ProgramMetadata,
26
// The first input's shape is the reference for validation and for the output shape.
const inputShape = inputs[0].dims.slice();
27
if (axis >= inputShape.length || axis < -1 * inputShape.length) {
28
throw new Error("axis specified for concat doesn't match input dimensionality");
31
// Convert a negative axis (counting from the end) into its non-negative index.
axis = inputShape.length + axis;
35
// Output shape starts as a copy of the first input's shape; the concat axis is
// accumulated across the remaining inputs below.
const outputShape = inputShape.slice(0);
36
for (let i = 1; i < inputs.length; i++) {
37
const dataNShape = inputs[i].dims.slice();
38
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
40
// Concat axis: sizes add up in the output.
if (axisIndex === axis) {
41
outputShape[axis] += dataNShape[axisIndex];
44
// Every other axis must match the first input exactly.
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
45
throw new Error('non concat dimensions must match');
50
const rank = outputShape.length;
51
const coords = getChannels('coords', rank);
52
const dtype = getCoordsDataType(rank);
53
const unpackChannel = unpackFromChannel();
55
const shapes = inputs.map((i) => i.dims);
56
const channels = getGlChannels(rank);
57
// offsets[i] is the cumulative size along `axis` of inputs 0..i, i.e. the exclusive
// upper bound of input i's range in the output. The last input needs no entry: its
// range ends at the output size.
// NOTE(review): with a single input this array starts at length 0, but the assignment
// to offsets[0] below grows it to length 1. lastIndex then becomes 1 and the snippet
// references getX1, while the metadata only declares X0. Presumably callers never route
// a one-input Concat here — confirm.
const offsets: number[] = new Array(shapes.length - 1);
59
offsets[0] = shapes[0][axis];
60
for (let i = 1; i < offsets.length; i++) {
61
offsets[i] = offsets[i - 1] + shapes[i][axis];
64
// GLSL channel name of the concat axis.
const channel = channels[axis];
65
// The two innermost channel names: the dimensions spanned by one packed texel.
const lastChannels = channels.slice(-2);
66
// Comma-separated list of all channel names, used as a call-argument list.
const allChannels = channels.join();
68
// Build a GLSL if-chain that chooses which input X{i} supplies a given output position.
// Input i covers [offsets[i-1], offsets[i]) along the concat axis. Its coordinate on that
// axis is shifted back by offsets[i-1]. The final input covers everything past the last
// offset.
let getValueSnippet = `if (${channel} < ${offsets[0]}) {
70
getX0(${allChannels}), vec2(${lastChannels.join()}));
72
for (let i = 1; i < offsets.length; i++) {
73
const shift = offsets[i - 1];
75
if (${channel} < ${offsets[i]} && ${channel} >= ${offsets[i - 1]}) {
77
getX${i}(${getShiftedChannelsSnippet(channels, channel, shift)}),
78
vec2(${getShiftedChannelsSnippet(lastChannels, channel, shift)}));
81
const lastIndex = offsets.length;
82
const shift = offsets[offsets.length - 1];
85
getX${lastIndex}(${getShiftedChannelsSnippet(channels, channel, shift)}),
86
vec2(${getShiftedChannelsSnippet(lastChannels, channel, shift)}));`;
88
// GLSL dialect helpers (e.g. the output variable name) for this context's WebGL version.
const glsl = getGlsl(handler.session.backend.glContext.version);
90
// Shader source. main() reads the output coordinates and swaps their last two
// components. It then fills result.r, .g, .a and .b by stepping the two innermost
// coordinates, calling getValue only for positions inside outputShape.
// NOTE(review): the code indexes channels[rank - 2], so it assumes rank >= 2. The reason
// for the coordinate swap is not evident here — confirm against the packed-layout
// conventions.
const shaderSource = `
92
float getValue(${channels.map((x) => 'int ' + x)}) {
97
${dtype} coords = getOutputCoords();
98
int lastDim = coords.${channels[rank - 1]};
99
coords.${channels[rank - 1]} = coords.${channels[rank - 2]};
100
coords.${channels[rank - 2]} = lastDim;
102
vec4 result = vec4(getValue(${coords}), 0., 0., 0.);
104
${coords[rank - 1]} = ${coords[rank - 1]} + 1;
105
if (${coords[rank - 1]} < ${outputShape[rank - 1]}) {
106
result.g = getValue(${coords});
109
${coords[rank - 2]} = ${coords[rank - 2]} + 1;
110
if (${coords[rank - 2]} < ${outputShape[rank - 2]}) {
111
result.a = getValue(${coords});
114
${coords[rank - 1]} = ${coords[rank - 1]} - 1;
115
if (${coords[rank - 2]} < ${outputShape[rank - 2]} &&
116
${coords[rank - 1]} < ${outputShape[rank - 1]}) {
117
result.b = getValue(${coords});
119
${glsl.output} = result;
125
output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
131
/**
 * Creates a program-info loader for the packed Concat operator.
 *
 * The metadata (program name plus per-input names and texture types) is computed
 * immediately from the number of inputs and attributes.cacheKey. The full program info
 * is built only when get() is called, using attributes.axis as the concatenation axis.
 */
export const createPackedConcatProgramInfoLoader = (
132
handler: WebGLInferenceHandler,
134
attributes: ConcatAttributes,
135
): ProgramInfoLoader => {
136
const metadata = createPackedConcatProgramMetadata(inputs.length, attributes.cacheKey);
137
// Program construction is deferred until get() is invoked.
return { ...metadata, get: () => createPackedConcatProgramInfo(handler, metadata, inputs, attributes.axis) };
140
const getShiftedChannelsSnippet = (channels: string[], channel: string, shift: number): string => {
141
const channelIdx = channels.indexOf(channel);
142
const res = channels.map((c, idx) => {
143
if (idx === channelIdx) {
144
return `${c} - ${shift}`;