onnxruntime

Форк
0
150 строк · 5.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { getGlsl } from '../glsl-source';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
8
import { getCoordsDataType, getGlChannels } from '../utils';
9

10
import { ConcatAttributes } from './concat';
11
import { getChannels, unpackFromChannel } from './packing-utils';
12

13
/**
 * Builds the static metadata for the packed Concat program.
 *
 * @param inputCount - number of tensors being concatenated; they are named `X0` .. `X{inputCount - 1}`,
 *   and every one of them is declared as a packed texture.
 * @param cacheHint - forwarded unchanged as the program's cache hint (the loader passes the attribute cache key).
 * @returns plain metadata object: name, input names, input texture types and cache hint.
 */
const createPackedConcatProgramMetadata = (inputCount: number, cacheHint: string) => {
  const inputNames = [...Array(inputCount).keys()].map((index) => `X${index}`);
  const inputTypes = inputNames.map(() => TextureType.packed);
  return {
    name: 'Concat (packed)',
    inputNames,
    inputTypes,
    cacheHint,
  };
};
19

20
/**
 * Builds the body of the GLSL `getValue(...)` function used by the packed Concat shader.
 *
 * The generated code selects which input `X<i>` owns the requested position along the concat
 * axis, and shifts that axis coordinate into the input's local range before sampling.
 *
 * @param channels - GLSL channel names for the output rank (e.g. `x`, `y`, `z`).
 * @param axis - non-negative concat axis.
 * @param shapes - dims of every input, in input order (at least one).
 * @returns GLSL statements that always end in a `return`.
 */
const getPackedConcatValueSnippet = (
  channels: string[],
  axis: number,
  shapes: ReadonlyArray<readonly number[]>,
): string => {
  const channel = channels[axis];
  const lastChannels = channels.slice(-2);
  const allChannels = channels.join();

  // A single input needs no dispatch. The generic fall-through below would reference a
  // sampler `X1` that the program never declares, and the shader would fail to compile.
  if (shapes.length === 1) {
    return `return getChannel(
            getX0(${allChannels}), vec2(${lastChannels.join()}));`;
  }

  // offsets[i] is the exclusive end (along `axis`) of input i in the output, for i < shapes.length - 1.
  const offsets: number[] = new Array(shapes.length - 1);
  offsets[0] = shapes[0][axis];
  for (let i = 1; i < offsets.length; i++) {
    offsets[i] = offsets[i - 1] + shapes[i][axis];
  }

  let snippet = `if (${channel} < ${offsets[0]}) {
        return getChannel(
            getX0(${allChannels}), vec2(${lastChannels.join()}));
        }`;
  for (let i = 1; i < offsets.length; i++) {
    const shift = offsets[i - 1];
    snippet += `
            if (${channel} < ${offsets[i]}  && ${channel} >= ${offsets[i - 1]}) {
              return getChannel(
                getX${i}(${getShiftedChannelsSnippet(channels, channel, shift)}),
                vec2(${getShiftedChannelsSnippet(lastChannels, channel, shift)}));
            }`;
  }
  // Everything past the last offset belongs to the final input.
  const lastIndex = offsets.length;
  const lastShift = offsets[offsets.length - 1];
  snippet += `
            return getChannel(
              getX${lastIndex}(${getShiftedChannelsSnippet(channels, channel, lastShift)}),
              vec2(${getShiftedChannelsSnippet(lastChannels, channel, lastShift)}));`;
  return snippet;
};

/**
 * Builds the WebGL program that concatenates packed-texture inputs along `axis`.
 *
 * @param handler - inference handler; used here only to pick the GLSL dialect from the GL context version.
 * @param metadata - program metadata, spread into the returned program info.
 * @param inputs - tensors to concatenate; every dimension other than `axis` must match `inputs[0]`.
 * @param axis - concat axis; negative values count from the last dimension.
 * @returns program info whose output has the concatenated shape and the data type of `inputs[0]`.
 * @throws Error if `axis` is out of range, or if a non-concat dimension differs between inputs.
 */
const createPackedConcatProgramInfo = (
  handler: WebGLInferenceHandler,
  metadata: ProgramMetadata,
  inputs: Tensor[],
  axis: number,
): ProgramInfo => {
  const inputShape = inputs[0].dims.slice();
  if (axis >= inputShape.length || axis < -1 * inputShape.length) {
    throw new Error("axis specified for concat doesn't match input dimensionality");
  }
  if (axis < 0) {
    axis = inputShape.length + axis;
  }
  // ensure all of the non-concatenated axes match each other
  // calculate the shape of the output tensor while we do that
  const outputShape = inputShape.slice(0);
  for (let i = 1; i < inputs.length; i++) {
    const dataNShape = inputs[i].dims.slice();
    for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
      // add to the placeholder for computing output shape
      if (axisIndex === axis) {
        outputShape[axis] += dataNShape[axisIndex];
      }
      // ensure all non-concatenated axes match each other
      else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
        throw new Error('non concat dimensions must match');
      }
    }
  }

  const rank = outputShape.length;
  const coords = getChannels('coords', rank);
  const dtype = getCoordsDataType(rank);
  const unpackChannel = unpackFromChannel();

  const shapes = inputs.map((i) => i.dims);
  const channels = getGlChannels(rank);
  const getValueSnippet = getPackedConcatValueSnippet(channels, axis, shapes);

  const glsl = getGlsl(handler.session.backend.glContext.version);

  // main() swaps the last two components of the output coords before sampling.
  // NOTE(review): presumably this compensates for the coordinate order returned by getOutputCoords
  // for packed outputs — confirm.
  // It then fills result.r at the start position, result.g after +1 on coords[rank - 1],
  // result.a after an additional +1 on coords[rank - 2], and result.b after -1 on coords[rank - 1].
  // Each extra read is bounds-checked against outputShape.
  const shaderSource = `
          ${unpackChannel}
          float getValue(${channels.map((x) => 'int ' + x)}) {
            ${getValueSnippet}
          }

          void main() {
            ${dtype} coords = getOutputCoords();
            int lastDim = coords.${channels[rank - 1]};
            coords.${channels[rank - 1]} = coords.${channels[rank - 2]};
            coords.${channels[rank - 2]} = lastDim;

            vec4 result = vec4(getValue(${coords}), 0., 0., 0.);

            ${coords[rank - 1]} = ${coords[rank - 1]} + 1;
            if (${coords[rank - 1]} < ${outputShape[rank - 1]}) {
              result.g = getValue(${coords});
            }

            ${coords[rank - 2]} = ${coords[rank - 2]} + 1;
            if (${coords[rank - 2]} < ${outputShape[rank - 2]}) {
              result.a = getValue(${coords});
            }

            ${coords[rank - 1]} = ${coords[rank - 1]} - 1;
            if (${coords[rank - 2]} < ${outputShape[rank - 2]} &&
                ${coords[rank - 1]} < ${outputShape[rank - 1]}) {
              result.b = getValue(${coords});
            }
            ${glsl.output} = result;
          }
        `;

  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.packed },
    shaderSource,
    hasMain: true,
  };
};
130

131
/**
 * Creates a lazy loader for the packed Concat program.
 *
 * The metadata is computed immediately. The full program info (including the shader source) is
 * only built when `get()` is called, and it reads `attributes.axis` at that point.
 *
 * @param handler - WebGL inference handler used when the program is built.
 * @param inputs - tensors to concatenate.
 * @param attributes - Concat attributes; `cacheKey` becomes the metadata cache hint.
 */
export const createPackedConcatProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: ConcatAttributes,
): ProgramInfoLoader => {
  const metadata = createPackedConcatProgramMetadata(inputs.length, attributes.cacheKey);
  const get = (): ProgramInfo => createPackedConcatProgramInfo(handler, metadata, inputs, attributes.axis);
  return { ...metadata, get };
};
139

140
/**
 * Renders a comma-separated GLSL argument list from `channels`.
 *
 * The first occurrence of `channel` is written as `<channel> - <shift>`; every other entry is
 * emitted unchanged. If `channel` is absent, the list is returned as-is.
 *
 * @param channels - channel expressions to join.
 * @param channel - the channel to shift.
 * @param shift - amount subtracted from that channel.
 */
const getShiftedChannelsSnippet = (channels: string[], channel: string, shift: number): string => {
  const target = channels.indexOf(channel);
  return channels.map((name, position) => (position === target ? `${name} - ${shift}` : name)).join(',');
};
151

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.