onnxruntime

Форк
0
236 строк · 6.9 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { getGlsl, Glsl } from '../glsl-source';
10
import { WebGLInferenceHandler } from '../inference-handler';
11
import { ProgramInfo, TextureType } from '../types';
12

13
/**
 * Attributes consumed by the WebGL Pad kernel.
 *
 * The inherited `cacheKey` is forwarded as the program `cacheHint` by `padV2`.
 */
export interface PadAttributes extends AttributeWithCacheKey {
  /** Padding mode; `getPadFunction` accepts 'constant', 'reflect' and 'edge'. */
  readonly mode: string;
  /** Fill value, only read by 'constant' mode. */
  readonly value: number;
  /**
   * Pad amounts. Entries `[0, rank)` are the leading pads per axis; the full layout is
   * whatever `ShapeUtil.padShape` expects.
   */
  readonly pads: number[];
}
18

19
/** Static metadata shared by every Pad program: one unpacked input texture bound as `A`. */
const padProgramMetadata = {
  name: 'Pad',
  inputTypes: [TextureType.unpacked],
  inputNames: ['A'],
};
24

25
/**
 * Attribute-based Pad implementation: mode, pads and fill value come from `attributes`.
 *
 * @param inferenceHandler WebGL handler used to build and run the program.
 * @param inputs exactly one float32/float64 tensor (checked by `validateInputsV2`).
 * @param attributes padding mode, pad amounts and constant fill value.
 * @returns a one-element array holding the padded tensor.
 * @throws Error when the inputs fail validation.
 */
export const padV2: OperatorImplementation<PadAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: PadAttributes,
): Tensor[] => {
  validateInputsV2(inputs);

  // The program info is produced lazily through `get`; `cacheHint` is keyed on the
  // attributes' cacheKey (presumably so the handler can reuse a built program).
  const programLoader = {
    ...padProgramMetadata,
    cacheHint: attributes.cacheKey,
    get: () => createPadProgramInfo(inferenceHandler, inputs[0], attributes),
  };

  const padded = inferenceHandler.run(programLoader, inputs);
  return [padded];
};
41

42
/**
 * Reads the attribute-based Pad configuration from a graph node.
 *
 * Defaults: `mode` = 'constant', `value` = 0.0. `pads` is read without a default.
 */
export const parsePadAttributesV2: OperatorInitialization<PadAttributes> = (node: Graph.Node): PadAttributes => {
  const { attributes } = node;
  // Key order (mode, value, pads) is kept as-is since the cache key is derived from this object.
  return createAttributeWithCacheKey({
    mode: attributes.getString('mode', 'constant'),
    value: attributes.getFloat('value', 0.0),
    pads: attributes.getInts('pads'),
  });
};
48

49
/**
 * Input-based Pad implementation: `pads` (inputs[1]) and the optional fill value
 * (inputs[2]) are tensors, while only `mode` is an attribute.
 *
 * The tensor-supplied parameters are folded into a `PadAttributes` object, and the work
 * is delegated to `padV2` with just the data tensor.
 *
 * @throws Error when the inputs fail validation or pads/value are not initializers.
 */
export const padV11: OperatorImplementation<string> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  mode: string,
): Tensor[] => {
  validateInputsV11(inputs);
  const [data] = inputs;
  const attributes = generatePadAttributesFromInputs(inferenceHandler, inputs, mode);
  return padV2(inferenceHandler, [data], attributes);
};
58

59
/** The input-based Pad variant carries only `mode` as an attribute (default 'constant'). */
export const parsePadAttributesV11: OperatorInitialization<string> = (node: Graph.Node): string => {
  const mode = node.attributes.getString('mode', 'constant');
  return mode;
};
61

62
/**
 * Builds `PadAttributes` for the input-based Pad variant.
 *
 * The pad amounts and fill value are later baked into the generated shader source as
 * literals, so both inputs must be graph initializers.
 *
 * @param inferenceHandler handler whose session is asked whether an input is an initializer.
 * @param inputs [data, pads, constant_value?].
 * @param mode padding mode attribute.
 * @returns attributes with a cache key; `value` defaults to 0.0 when inputs[2] is absent.
 * @throws Error('dynamic pad attributes are not allowed') when pads/value are not initializers.
 */
const generatePadAttributesFromInputs = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  mode: string,
): PadAttributes => {
  const { session } = inferenceHandler;
  const hasValueInput = inputs.length >= 3;

  // Short-circuits exactly like a single `||`: inputs[2] is only consulted when pads are static.
  const isDynamic =
    !session.isInitializer(inputs[1].dataId) || (hasValueInput && !session.isInitializer(inputs[2].dataId));
  if (isDynamic) {
    throw new Error('dynamic pad attributes are not allowed');
  }

  const pads = Array.from(inputs[1].integerData);
  const value = hasValueInput ? inputs[2].floatData[0] : 0.0;

  return createAttributeWithCacheKey({ mode, pads, value });
};
79

80
/**
 * Builds the WebGL program that produces the padded output.
 *
 * The output shape comes from `ShapeUtil.padShape`; each output element is computed by
 * the mode-specific `padA` helper emitted by `getPadFunction`.
 *
 * @param inferenceHandler WebGL handler (used for GLSL version and texture layout).
 * @param input the tensor being padded.
 * @param attributes padding mode, pad amounts and fill value.
 * @returns program info with the shared Pad metadata, output layout and shader source.
 * @throws Error from `getPadFunction` when `attributes.mode` is not recognized.
 */
const createPadProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  attributes: PadAttributes,
): ProgramInfo => {
  // Pass a copy of dims so padShape cannot touch the tensor's own array.
  const outputShape = ShapeUtil.padShape(input.dims.slice(), attributes.pads);
  const rank = outputShape.length;
  const padFunction = getPadFunction(inferenceHandler, input, attributes);
  const shaderSource = `
      ${padFunction}
      float process(int[${rank}] indices) {
          return padA(indices);
      }`;
  return {
    // Reuse the shared metadata so name/inputs cannot drift from what padV2 registers.
    ...padProgramMetadata,
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
101

102
/**
 * Validates inputs of the attribute-based Pad: exactly one float32 or float64 tensor.
 *
 * @throws Error on wrong arity or unsupported element type.
 */
const validateInputsV2 = (inputs: Tensor[]): void => {
  const hasSingleInput = !!inputs && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Pad requires 1 input');
  }

  const supportedTypes = ['float32', 'float64'];
  if (!supportedTypes.includes(inputs[0].type)) {
    throw new Error('Invalid input type.');
  }
};
110

111
/**
 * Validates inputs of the input-based Pad: [data, pads, constant_value?].
 *
 * `pads` must be int32; an optional `constant_value` must not be a string tensor.
 * The data tensor itself is validated later by `padV2`.
 *
 * @throws Error on wrong arity or unsupported element types.
 */
const validateInputsV11 = (inputs: Tensor[]): void => {
  const count = inputs ? inputs.length : 0;
  if (count < 2 || count > 3) {
    throw new Error('Pad requires 2 or 3 inputs');
  }

  const pads = inputs[1];
  if (pads.type !== 'int32') {
    throw new Error('Invalid input type.');
  }

  if (count === 3 && inputs[2].type === 'string') {
    throw new Error('Invalid input type.');
  }
};
122

123
/**
 * Selects and emits the mode-specific GLSL `padA` helper for `input`.
 *
 * @returns GLSL source defining `padA`.
 * @throws Error('Invalid mode') for modes other than 'constant', 'reflect' and 'edge'.
 */
const getPadFunction = (inferenceHandler: WebGLInferenceHandler, input: Tensor, attributes: PadAttributes): string => {
  const { dims } = input;
  const { mode, pads } = attributes;

  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
  const [width, height] = inferenceHandler.calculateTextureWidthAndHeight(dims, TextureType.unpacked);
  const strides = ShapeUtil.computeStrides(dims);

  if (mode === 'constant') {
    return getPadConstant(glsl, dims, strides, width, height, pads, attributes.value);
  }
  if (mode === 'reflect') {
    return getPadReflect(glsl, dims, strides, width, height, pads);
  }
  if (mode === 'edge') {
    return getPadEdge(glsl, dims, strides, width, height, pads);
  }
  throw new Error('Invalid mode');
};
139

140
/**
 * Emits the GLSL `padA` helper for 'constant' padding.
 *
 * For each axis (last axis first) the output coordinate is shifted by the leading pad.
 * Any coordinate outside `[0, shape[i])` returns the fill value immediately; otherwise it
 * is accumulated into a flat offset, which is converted to texture coordinates and
 * sampled from `A`.
 *
 * @param glsl GLSL dialect helpers (only `texture2D` is used).
 * @param shape input dims.
 * @param strides input strides.
 * @param width texture width passed to `offsetToCoords`.
 * @param height texture height passed to `offsetToCoords`.
 * @param pads pad amounts; `pads[i]` for i < rank is the leading pad of axis i.
 * @param value fill value, embedded as a GLSL `float(...)` literal.
 * @returns GLSL source defining `float padA(int m[rank])`.
 */
const getPadConstant = (
  glsl: Glsl,
  shape: readonly number[],
  strides: readonly number[],
  width: number,
  height: number,
  pads: number[],
  value: number,
): string => {
  const rank = shape.length;
  const axesLastFirst = Array.from({ length: rank }, (_, j) => rank - 1 - j);
  const block = axesLastFirst
    .map(
      (axis) => `
        k = m[${axis}] - ${pads[axis]};
        if (k < 0)  return constant;
        if (k >= ${shape[axis]}) return constant;
        offset += k * ${strides[axis]};
        `,
    )
    .join('');
  return `
      float padA(int m[${rank}]) {
        const float constant = float(${value});
        int offset = 0;
        int k = 0;
        ${block}
        vec2 coords = offsetToCoords(offset, ${width}, ${height});
        float value = getColorAsFloat(${glsl.texture2D}(A, coords));
        return value;
      }
      `;
};
171

172
/**
 * Emits the GLSL `padA` helper for 'reflect' padding.
 *
 * For each axis (last axis first) the output coordinate is shifted by the leading pad,
 * mirrored at 0 (`k = -k`), folded into the period 2*(n-1) with `mod`, and mirrored back
 * from the far edge when it lands in the upper half of that period. The resulting
 * in-bounds index is accumulated into a flat offset, which is converted to texture
 * coordinates and sampled from `A`.
 *
 * Axes of size 1 are special-cased. Their period 2*(n-1) is 0, and the generated
 * `mod(x, 0.0)` would be undefined in GLSL, even on unpadded size-1 axes such as a
 * batch of 1. Every coordinate on such an axis reflects to index 0, which adds nothing
 * to the offset, so no code is emitted for it.
 *
 * @param glsl GLSL dialect helpers (only `texture2D` is used).
 * @param shape input dims.
 * @param strides input strides.
 * @param width texture width passed to `offsetToCoords`.
 * @param height texture height passed to `offsetToCoords`.
 * @param pads pad amounts; `pads[i]` for i < rank is the leading pad of axis i.
 * @returns GLSL source defining `float padA(int m[rank])`.
 */
const getPadReflect = (
  glsl: Glsl,
  shape: readonly number[],
  strides: readonly number[],
  width: number,
  height: number,
  pads: number[],
): string => {
  const rank = shape.length;

  let block = '';
  for (let i = rank - 1; i >= 0; --i) {
    if (shape[i] === 1) {
      // Single-element axis: the reflected index is always 0, so the offset is unchanged
      // and emitting the mod-based fold would divide by a zero period.
      continue;
    }
    block += `
        k = m[${i}] - ${pads[i]};
        if (k < 0) { k = -k; }
        {
          const int _2n_1 = ${2 * (shape[i] - 1)};
          k = int( mod( float(k), float(_2n_1) ) ) ;
          if(k >= ${shape[i]}) { k = _2n_1 - k; }
        }
        offset += k * ${strides[i]};
        `;
  }
  return `
      float padA(int m[${rank}]) {
        int offset = 0;
        int k = 0;
        ${block}
        vec2 coords = offsetToCoords(offset, ${width}, ${height});
        float value = getColorAsFloat(${glsl.texture2D}(A, coords));
        return value;
      }
      `;
};
206

207
/**
 * Emits the GLSL `padA` helper for 'edge' padding.
 *
 * For each axis (last axis first) the output coordinate is shifted by the leading pad and
 * clamped into `[0, shape[i] - 1]`. The result is accumulated into a flat offset, which
 * is converted to texture coordinates and sampled from `A`.
 *
 * @param glsl GLSL dialect helpers (only `texture2D` is used).
 * @param shape input dims.
 * @param strides input strides.
 * @param width texture width passed to `offsetToCoords`.
 * @param height texture height passed to `offsetToCoords`.
 * @param pads pad amounts; `pads[i]` for i < rank is the leading pad of axis i.
 * @returns GLSL source defining `float padA(int m[rank])`.
 */
const getPadEdge = (
  glsl: Glsl,
  shape: readonly number[],
  strides: readonly number[],
  width: number,
  height: number,
  pads: number[],
): string => {
  const rank = shape.length;

  const clampSnippets: string[] = [];
  for (let axis = rank - 1; axis >= 0; axis--) {
    const size = shape[axis];
    clampSnippets.push(`
        k = m[${axis}] - ${pads[axis]};
        if (k < 0)  k = 0;
        if (k >= ${size}) k = ${size - 1};
        offset += k * ${strides[axis]};
      `);
  }
  const block = clampSnippets.join('');

  return `
      float padA(int m[${rank}]) {
        int offset = 0;
        int k = 0;
        ${block}
        vec2 coords = offsetToCoords(offset, ${width}, ${height});
        float value = getColorAsFloat(${glsl.texture2D}(A, coords));
        return value;
      }
      `;
};
237

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.