/*
Web code-viewer page header captured together with the source:

onnxruntime

Форк
0
433 строки · 13.2 Кб
1
*/
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { PoolConvUtil, ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
11

12
/**
 * Attributes of an AveragePool node as produced by `parseAveragePoolAttributes`.
 * `parseGlobalAveragePoolAttributes` also returns this shape, with empty kernel/stride/pad arrays.
 */
export interface AveragePoolAttributes extends AttributeWithCacheKey {
  /** Spatial kernel size (`kernel_shape`). */
  readonly kernelShape: readonly number[];
  /** Spatial strides (`strides`); empty when the node does not set them. */
  readonly strides: readonly number[];
  /** Paddings (`pads`); empty when the node does not set them. */
  readonly pads: readonly number[];
  /** Value of `auto_pad` ('NOTSET' when the node does not set it). */
  readonly autoPad: string;
  /** Value of `ceil_mode`; the parsers in this file reject anything other than 0. */
  readonly ceilMode: number;
  /** Whether padded positions count towards the averaging divisor (`count_include_pad`). */
  readonly countIncludePad: boolean;
}
20

21
/**
 * WebGL kernel for AveragePool.
 *
 * @param inferenceHandler handler used to run the generated program
 * @param inputs exactly one float32/float64 tensor `X`
 * @param attributes attributes produced by `parseAveragePoolAttributes`
 * @returns a one-element array holding the pooled tensor
 */
export const averagePool: OperatorImplementation<AveragePoolAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: AveragePoolAttributes,
): Tensor[] => {
  validateInputs(inputs);
  const programMetadata = {
    name: 'AveragePool',
    inputNames: ['X'],
    inputTypes: [TextureType.unpacked],
    // NOTE(review): presumably distinguishes cached programs built from different attributes.
    cacheHint: attributes.cacheKey,
  };
  const programLoader = {
    ...programMetadata,
    get: () => createAveragePoolProgramInfo(inputs, programMetadata, false, attributes),
  };
  const pooled = inferenceHandler.run(programLoader, inputs);
  return [pooled];
};
39

40
/**
 * Reads the AveragePool attributes of a graph node.
 *
 * @param node graph node being initialized
 * @returns the parsed attributes, wrapped with a cache key
 * @throws Error when `ceil_mode` is non-zero (not supported by this backend)
 */
export const parseAveragePoolAttributes: OperatorInitialization<AveragePoolAttributes> = (
  node: Graph.Node,
): AveragePoolAttributes => {
  const attrs = node.attributes;
  const autoPad = attrs.getString('auto_pad', 'NOTSET');
  const ceilMode = attrs.getInt('ceil_mode', 0);
  const countIncludePad = attrs.getInt('count_include_pad', 0) !== 0;
  const kernelShape = attrs.getInts('kernel_shape');
  const strides = attrs.getInts('strides', []);
  const pads = attrs.getInts('pads', []);

  // TODO: support attribute 'ceil_mode'
  if (ceilMode !== 0) {
    throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
  }

  return createAttributeWithCacheKey({ autoPad, ceilMode, countIncludePad, kernelShape, strides, pads });
};
57

58
/**
 * Builds the WebGL program for AveragePool / GlobalAveragePool.
 *
 * @param inputs single input tensor `X` (validated by the calling kernel)
 * @param metadata program metadata supplied by the calling kernel
 * @param isGlobalOperator true for GlobalAveragePool; forwarded to the attribute/shape adjustment
 * @param attributes parsed attributes; adjusted for the input shape before code generation
 * @returns program info with an unpacked output of the pooled shape and the input's element type
 */
const createAveragePoolProgramInfo = (
  inputs: Tensor[],
  metadata: ProgramMetadata,
  isGlobalOperator: boolean,
  attributes: AveragePoolAttributes,
): ProgramInfo => {
  const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
    inputs,
    attributes,
    isGlobalOperator,
  );
  const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape);
  const op1 = 'value += _X(x);';
  // `pad` is counted by the generated loop for window positions that fall outside the input;
  // they are excluded from the divisor unless count_include_pad is set.
  const op2 = adjustedAttributes.countIncludePad
    ? `value /= float(${kernelSize});`
    : `value /= float(${kernelSize} - pad);`;
  const poolingCode = generatePoolingCode(inputs[0].dims, adjustedAttributes, op1, op2, '0.0');
  const shaderSource = `
        ${poolingCode}
      `;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
87

88
/**
 * WebGL kernel for GlobalAveragePool.
 *
 * @param inferenceHandler handler used to run the generated program
 * @param inputs exactly one float32/float64 tensor `X`
 * @param attributes attributes produced by `parseGlobalAveragePoolAttributes`
 * @returns a one-element array holding the pooled tensor
 */
export const globalAveragePool: OperatorImplementation<AveragePoolAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: AveragePoolAttributes,
): Tensor[] => {
  validateInputs(inputs);
  // count_include_pad is the only node-provided attribute, so it alone forms the cache hint.
  const programMetadata = {
    name: 'GlobalAveragePool',
    inputNames: ['X'],
    inputTypes: [TextureType.unpacked],
    cacheHint: String(attributes.countIncludePad),
  };
  const programLoader = {
    ...programMetadata,
    get: () => createAveragePoolProgramInfo(inputs, programMetadata, true, attributes),
  };
  return [inferenceHandler.run(programLoader, inputs)];
};
106

107
/**
 * Reads the attributes of a GlobalAveragePool node.
 *
 * Only `count_include_pad` comes from the node. The kernel, stride and pad arrays start empty.
 * NOTE(review): they are presumably filled from the input shape by `PoolConvUtil.adjustPoolAttributes`
 * (called with `isGlobalOperator = true`); confirm there.
 *
 * @param node graph node being initialized
 * @returns the attributes, wrapped with a cache key
 */
export const parseGlobalAveragePoolAttributes: OperatorInitialization<AveragePoolAttributes> = (
  node: Graph.Node,
): AveragePoolAttributes => {
  const countIncludePad = node.attributes.getInt('count_include_pad', 0) !== 0;
  const attributes = {
    autoPad: '',
    ceilMode: 0,
    countIncludePad,
    kernelShape: [],
    strides: [],
    pads: [],
  };
  return createAttributeWithCacheKey(attributes);
};
120

121
/** Attributes of a MaxPool node: the AveragePool set plus `dilations` and `storage_order`. */
export interface MaxPoolAttributes extends AveragePoolAttributes {
  /**
   * Per-axis kernel dilations (`dilations`); empty when the node does not set them.
   * NOTE(review): unlike the other array attributes this is declared mutable (`number[]`).
   */
  readonly dilations: number[];
  /** Value of `storage_order`; `parseMaxPoolAttributes` rejects anything other than 0. */
  readonly storageOrder: number;
}
125

126
/**
 * WebGL kernel for MaxPool.
 *
 * @param inferenceHandler handler used to run the generated program
 * @param inputs exactly one float32/float64 tensor `X`
 * @param attributes attributes produced by `parseMaxPoolAttributes`
 * @returns a one-element array holding the pooled tensor
 */
export const maxPool: OperatorImplementation<MaxPoolAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: MaxPoolAttributes,
): Tensor[] => {
  validateInputs(inputs);
  const programMetadata = {
    name: 'MaxPool',
    inputNames: ['X'],
    inputTypes: [TextureType.unpacked],
    cacheHint: attributes.cacheKey,
  };
  const pooled = inferenceHandler.run(
    { ...programMetadata, get: () => createMaxPoolProgramInfo(inputs, programMetadata, false, attributes) },
    inputs,
  );
  return [pooled];
};
144

145
/**
 * Reads the MaxPool attributes of a graph node.
 *
 * @param node graph node being initialized
 * @returns the parsed attributes, wrapped with a cache key (`countIncludePad` is always false)
 * @throws Error when `storage_order` is non-zero (checked first) or `ceil_mode` is non-zero
 */
export const parseMaxPoolAttributes: OperatorInitialization<MaxPoolAttributes> = (
  node: Graph.Node,
): MaxPoolAttributes => {
  const attrs = node.attributes;
  const autoPad = attrs.getString('auto_pad', 'NOTSET');
  const ceilMode = attrs.getInt('ceil_mode', 0);
  const kernelShape = attrs.getInts('kernel_shape');
  const strides = attrs.getInts('strides', []);
  const pads = attrs.getInts('pads', []);
  const storageOrder = attrs.getInt('storage_order', 0);
  const dilations = attrs.getInts('dilations', []);

  // TODO: support attribute 'ceil_mode' and 'storage_order'
  const unsupported =
    storageOrder !== 0
      ? 'column major storage order is not yet supported for MaxPool'
      : ceilMode !== 0
        ? 'using ceil() in shape computation is not yet supported for MaxPool'
        : undefined;
  if (unsupported !== undefined) {
    throw new Error(unsupported);
  }

  return createAttributeWithCacheKey({
    autoPad,
    ceilMode,
    // MaxPool has no count_include_pad; the shared field is fixed to false.
    countIncludePad: false,
    kernelShape,
    strides,
    pads,
    storageOrder,
    dilations,
  });
};
175

176
/**
 * Builds the WebGL program for MaxPool / GlobalMaxPool.
 *
 * @param inputs single input tensor `X` (validated by the calling kernel)
 * @param metadata program metadata supplied by the calling kernel
 * @param isGlobalOperator true for GlobalMaxPool; forwarded to the attribute/shape adjustment
 * @param attributes parsed attributes; adjusted for the input shape before code generation
 * @returns program info with an unpacked output of the pooled shape and the input's element type
 */
const createMaxPoolProgramInfo = (
  inputs: Tensor[],
  metadata: ProgramMetadata,
  isGlobalOperator: boolean,
  attributes: MaxPoolAttributes,
): ProgramInfo => {
  const [adjustedAttributes, outputShape] = getAdjustedPoolAttributesAndOutputShape(
    inputs,
    attributes,
    isGlobalOperator,
  );
  const op1 = `
      value = max(_X(x), value);
    `;
  // The running maximum is already the result; nothing to do after the window.
  const op2 = '';
  // Start from the lowest finite float32. The previous sentinel (-1e5) made every window whose
  // values were all below -1e5 report -1e5 instead of its true maximum.
  const lowestFloat = '-3.402823466e+38';
  const poolingCode = generatePoolingCode(inputs[0].dims, adjustedAttributes, op1, op2, lowestFloat);
  const shaderSource = `
      ${poolingCode}
    `;
  return {
    ...metadata,
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
201

202
/**
 * Fills in default kernel/stride/pad (and, for MaxPool, dilation) values for the given input and computes
 * the pooled output shape.
 *
 * The attribute arrays are copied before `PoolConvUtil` adjusts them in place, so `attributes` itself is left
 * untouched. The returned attributes are a shallow copy carrying the adjusted arrays and the original cache key.
 *
 * @param inputs inputs of the pool op; only `inputs[0].dims` is read
 * @param attributes parsed AveragePool or MaxPool attributes
 * @param isGlobalOperator forwarded to both `PoolConvUtil` helpers
 * @returns `[adjustedAttributes, outputShape]`
 */
const getAdjustedPoolAttributesAndOutputShape = (
  inputs: Tensor[],
  attributes: AveragePoolAttributes | MaxPoolAttributes,
  isGlobalOperator: boolean,
): [AveragePoolAttributes | MaxPoolAttributes, number[]] => {
  const shape = inputs[0].dims.slice();
  // Only MaxPool attributes carry an own `dilations` property.
  const isMaxPool = Object.hasOwnProperty.call(attributes, 'dilations');
  const adjustedKernel = attributes.kernelShape.slice();
  const adjustedStrides = attributes.strides.slice();
  const adjustedDilations: number[] = isMaxPool ? (attributes as MaxPoolAttributes).dilations.slice() : [];
  const adjustedPads = attributes.pads.slice();

  PoolConvUtil.adjustPoolAttributes(
    isGlobalOperator,
    shape,
    adjustedKernel,
    adjustedStrides,
    adjustedDilations,
    adjustedPads,
  );
  const outputShape = PoolConvUtil.computePoolOutputShape(
    isGlobalOperator,
    shape,
    adjustedStrides,
    adjustedDilations,
    adjustedKernel,
    adjustedPads,
    attributes.autoPad,
  );

  // Copy first, then read the cache key, matching the original evaluation order.
  const adjusted = Object.assign({}, attributes);
  Object.assign(adjusted, {
    kernelShape: adjustedKernel,
    strides: adjustedStrides,
    pads: adjustedPads,
    ...(isMaxPool ? { dilations: adjustedDilations } : {}),
    cacheKey: attributes.cacheKey,
  });
  return [adjusted, outputShape];
};
233

234
/**
 * Constant attributes used by GlobalMaxPool. The shape-related arrays start empty and are populated by
 * `getAdjustedPoolAttributesAndOutputShape` (called with `isGlobalOperator = true`) for each input.
 */
const globalMaxPoolAttributes = {
  // Shape-related arrays, populated per input shape.
  kernelShape: [],
  strides: [],
  pads: [],
  dilations: [],
  // Scalar attributes fixed for the global variant.
  autoPad: '',
  ceilMode: 0,
  countIncludePad: false,
  storageOrder: 0,
  // Empty cache key; `globalMaxPoolMetadata` sets no cacheHint either.
  cacheKey: '',
};
245

246
/**
 * Program metadata for GlobalMaxPool. No `cacheHint` is set; the program is always built from the
 * fixed `globalMaxPoolAttributes`.
 */
const globalMaxPoolMetadata = {
  name: 'GlobalMaxPool',
  inputNames: ['X'], inputTypes: [TextureType.unpacked],
};
251

252
/**
 * WebGL kernel for GlobalMaxPool. It takes no node attributes and always uses `globalMaxPoolAttributes`.
 *
 * @param inferenceHandler handler used to run the generated program
 * @param inputs exactly one float32/float64 tensor `X`
 * @returns a one-element array holding the pooled tensor
 */
export const globalMaxPool = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  validateInputs(inputs);
  const programLoader = {
    ...globalMaxPoolMetadata,
    get: () => createMaxPoolProgramInfo(inputs, globalMaxPoolMetadata, true, globalMaxPoolAttributes),
  };
  const pooled = inferenceHandler.run(programLoader, inputs);
  return [pooled];
};
263

264
/**
 * Checks the input contract shared by every pool kernel in this file.
 *
 * @param inputs kernel inputs; exactly one tensor of type float32 or float64 is accepted
 * @throws Error 'Pool ops requires 1 input.' when `inputs` is missing or does not hold exactly one tensor
 * @throws Error 'Invalid input type.' for any other element type
 */
const validateInputs = (inputs: Tensor[]): void => {
  const hasSingleInput = Boolean(inputs) && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Pool ops requires 1 input.');
  }
  const inputType = inputs[0].type;
  const isFloatInput = inputType === 'float32' || inputType === 'float64';
  if (!isFloatInput) {
    throw new Error('Invalid input type.');
  }
};
272

273
/**
 * Generates the GLSL `process()` function shared by the AveragePool and MaxPool shaders.
 *
 * For each output coordinate `indices`, the generated code visits the pooling window over the input:
 * - the input coordinate being visited is held in `x`;
 * - every in-bounds element is folded into `value` by `op1`;
 * - every out-of-bounds (padded) position increments `pad`.
 * After the window, `op2` runs once and `value` is returned.
 *
 * Kernel taps along an axis are spaced by that axis' dilation when the attributes carry `dilations`
 * (MaxPool), otherwise by 1. Previously dilations only affected the output shape, not the sampled positions.
 *
 * @param inputDims dims of the input tensor; the trailing `kernelShape.length` dims are the pooled axes
 * @param attributes adjusted pool attributes (the callers in this file pass the result of
 *   `getAdjustedPoolAttributesAndOutputShape`)
 * @param op1 GLSL statement(s) applied to each in-bounds input element, e.g. `value += _X(x);`
 * @param op2 GLSL statement(s) applied once after the whole window has been visited
 * @param start GLSL float literal used to initialise the accumulator `value`
 * @returns GLSL source defining `float process(int indices[rank])`
 */
const generatePoolingCode = (
  inputDims: readonly number[],
  attributes: AveragePoolAttributes,
  op1: string,
  op2: string,
  start: string,
): string => {
  const dilations = readPoolDilations(attributes);
  return attributes.kernelShape.length <= 2
    ? generateLowRankPoolingCode(inputDims, attributes, dilations, op1, op2, start)
    : generateHighRankPoolingCode(inputDims, attributes, dilations, op1, op2, start);
};

/** Returns the attributes' `dilations` array when present (MaxPool), or an empty array (AveragePool). */
const readPoolDilations = (attributes: AveragePoolAttributes): readonly number[] => {
  // `dilations` is not declared on AveragePoolAttributes; MaxPool attributes carry it as a property.
  const dilations: unknown = Reflect.get(attributes, 'dilations');
  return Array.isArray(dilations) ? dilations : [];
};

/** GLSL offset of kernel tap `loopVar` along one axis; omits the multiply for the common dilation of 1. */
const dilatedTap = (loopVar: string, dilation: number): string =>
  dilation === 1 ? loopVar : `${loopVar} * ${dilation}`;

/** Emits `process()` for 1-D and 2-D kernels as (at most two) nested loops over the window. */
const generateLowRankPoolingCode = (
  inputDims: readonly number[],
  attributes: AveragePoolAttributes,
  dilations: readonly number[],
  op1: string,
  op2: string,
  start: string,
): string => {
  const rank = inputDims.length;
  const { kernelShape, strides, pads } = attributes;
  const kernelRank = kernelShape.length;

  // Innermost (W) axis. Begin pads occupy the first half of `pads`, end pads the second half.
  const kw = kernelShape[kernelRank - 1];
  const sw = strides[strides.length - 1];
  const pwStart = pads[pads.length / 2 - 1];
  const pwEnd = pads[pads.length - 1];
  const dimW = inputDims[rank - 1];
  const tapW = dilatedTap('i', dilations[kernelRank - 1] ?? 1);
  // The bounds check is only emitted when the W axis is padded.
  const boundsCheckW =
    pwStart + pwEnd !== 0
      ? `
            if (x[${rank} - 1] < 0 || x[${rank} - 1] >= ${dimW}) {
              pad++;
              continue;
            }`
      : '';
  const codeW = `
          for (int i = 0; i < ${kw}; i++) {
            x[${rank} - 1] = indices[${rank} - 1] * ${sw} - ${pwStart} + ${tapW};${boundsCheckW}
            ${op1}
          }`;

  // Optional outer (H) axis: opens a loop around the W loop; `codeHEnd` closes it.
  let codeH = '';
  let codeHEnd = '';
  if (kernelRank === 2) {
    const kh = kernelShape[kernelRank - 2];
    const sh = strides[strides.length - 2];
    const phStart = pads[pads.length / 2 - 2];
    const phEnd = pads[pads.length - 2];
    const dimH = inputDims[rank - 2];
    const tapH = dilatedTap('j', dilations[kernelRank - 2] ?? 1);
    // Skipping a padded row skips its `kw` taps at once, so it counts as `kw` padded elements.
    const boundsCheckH =
      phStart + phEnd !== 0
        ? `
              if (x[${rank} - 2] < 0 || x[${rank} - 2] >= ${dimH}) {
                pad+= ${kw};
                continue;
              }`
        : '';
    codeH = `
            for (int j = 0; j < ${kh}; j++) {
              x[${rank} - 2] = indices[${rank} - 2] * ${sh} - ${phStart} + ${tapH};${boundsCheckH}
          `;
    codeHEnd = `
          }
        `;
  }

  return `
        float process(int indices[${rank}]) {
          int x[${rank}];
          copyVec(indices, x);

          float value = ${start};
          int pad = 0;
          ${codeH}
          ${codeW}
          ${codeHEnd}
          ${op2}
          return value;
        }
      `;
};

/** Emits `process()` for kernels of rank > 2 as one flat loop over all kernel taps. */
const generateHighRankPoolingCode = (
  inputDims: readonly number[],
  attributes: AveragePoolAttributes,
  dilations: readonly number[],
  op1: string,
  op2: string,
  start: string,
): string => {
  const rank = inputDims.length;
  const kernelSize = ShapeUtil.size(attributes.kernelShape);
  const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
  const stridesRank = kernelStrides.length;
  const padsRank = attributes.pads.length;
  // Per spatial axis dilation; 1 wherever the attributes carry none (AveragePool).
  const kernelDilations = attributes.kernelShape.map((_, axis) => dilations[axis] ?? 1);
  // GLSL index of the spatial axis visited by loop variable `j`. It is kept inline (loop index plus constants)
  // so it stays a constant-index-expression for WebGL1 array indexing.
  const axis = `j - ${rank} + ${stridesRank}`;
  const hasPads = attributes.pads.some((p) => p !== 0);
  // With padding, a tap is skipped (and counted in `pad`) as soon as one of its coordinates is out of bounds.
  const boundsCheck = hasPads
    ? `
              if (x[j] >= inputDims[j] || x[j] < 0) {
                pad++;
                isPad = true;
                break;
              }`
    : '';
  const accumulate = hasPads
    ? `
            if (!isPad) {
              ${op1}
            }`
    : `
            ${op1}`;

  return `
        ${offsetToIndices(stridesRank)}
        float process(int indices[${rank}]) {
          int x[${rank}];
          copyVec(indices, x);
          int offset[${stridesRank}];
          int pads[${padsRank}];
          int inputDims[${rank}];
          int kernelStrides[${stridesRank}];
          int strides[${stridesRank}];
          int dilations[${stridesRank}];
          ${copyArray(attributes.pads, 'pads')}
          ${copyArray(inputDims, 'inputDims')}
          ${copyArray(attributes.strides, 'strides')}
          ${copyArray(kernelStrides, 'kernelStrides')}
          ${copyArray(kernelDilations, 'dilations')}

          float value = ${start};
          int pad = 0;
          bool isPad = false;
          for (int i = 0; i < ${kernelSize}; i++) {
            offsetToIndices(i, kernelStrides, offset);
            isPad = false;
            for (int j = ${rank} - ${stridesRank}; j < ${rank}; j++) {
              x[j] = indices[j] * strides[${axis}]
                + offset[${axis}] * dilations[${axis}] - pads[${axis}];${boundsCheck}
            }${accumulate}
          }
          ${op2}

          return value;
        }
      `;
};
412

413
/**
 * Emits GLSL statements that assign each element of `array` to `arrayName[i]`.
 *
 * The previous template also contained stray line-number text, which was emitted into the shader.
 *
 * @param array values to copy (emitted as integer literals via string interpolation)
 * @param arrayName name of the already-declared GLSL array to fill
 * @returns one assignment statement per element, or '' for an empty array
 */
const copyArray = (array: readonly number[], arrayName: string): string =>
  array
    .map(
      (value, index) => `
      ${arrayName}[${index}] = ${value};
    `,
    )
    .join('');
422

423
/**
 * Emits a GLSL helper that decodes a flat kernel offset into per-axis kernel indices.
 *
 * The generated `offsetToIndices(offset, strides, indices)` divides `offset` by each stride in turn,
 * keeping the remainder; the last axis receives what is left. The previous template also contained
 * stray line-number text, which was emitted into the shader.
 *
 * @param rank number of kernel axes (length of both GLSL array parameters)
 * @returns GLSL source of the helper
 */
const offsetToIndices = (rank: number): string => `
  void offsetToIndices(int offset, int[${rank}] strides, out int[${rank}] indices) {
    if (${rank} == 0) {
      return;
    }
    for (int i = 0; i < ${rank} - 1; ++i) {
      indices[i] = offset / strides[i];
      offset -= indices[i] * strides[i];
    }
    indices[${rank} - 1] = offset;
  }`;
434

/*
Web code-viewer cookie notice captured together with the source:

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.
*/