onnxruntime

Форк
0
99 строк · 3.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
10

11
export interface ImageScalerAttributes extends AttributeWithCacheKey {
12
  scale: number;
13
  bias: number[];
14
}
15

16
export const imageScaler: OperatorImplementation<ImageScalerAttributes> = (
17
  inferenceHandler: WebGLInferenceHandler,
18
  inputs: Tensor[],
19
  attributes: ImageScalerAttributes,
20
): Tensor[] => {
21
  validateInputs(inputs);
22
  const output = inferenceHandler.run(createImageScalerProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
23
  return [output];
24
};
25

26
export const parseImageScalerAttributes: OperatorInitialization<ImageScalerAttributes> = (
27
  node: Graph.Node,
28
): ImageScalerAttributes => {
29
  const scale = node.attributes.getFloat('scale');
30
  const bias = node.attributes.getFloats('bias');
31
  return createAttributeWithCacheKey({ scale, bias });
32
};
33

34
const imageScalerProgramMetadata = {
35
  name: 'ImageScaler',
36
  inputNames: ['X'],
37
  inputTypes: [TextureType.unpacked],
38
};
39

40
const createImageScalerProgramInfo = (
41
  _handler: WebGLInferenceHandler,
42
  metadata: ProgramMetadata,
43
  inputs: Tensor[],
44
  attributes: ImageScalerAttributes,
45
): ProgramInfo => {
46
  const outputShape = inputs[0].dims.slice();
47
  const rank = outputShape.length;
48
  const getBiasMethod = createGetBiasMethod(attributes.bias.length);
49
  const shaderSource = `
50
      ${getBiasMethod}
51
      float process(int indices[${rank}]) {
52
        return _X(indices) * scale + getBias(bias, indices[1]);
53
      }`;
54
  return {
55
    ...metadata,
56
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
57
    variables: [
58
      { name: 'bias', type: 'float', arrayLength: attributes.bias.length, data: attributes.bias },
59
      { name: 'scale', type: 'float', data: attributes.scale },
60
    ],
61
    shaderSource,
62
  };
63
};
64

65
const createImageScalerProgramInfoLoader = (
66
  handler: WebGLInferenceHandler,
67
  inputs: Tensor[],
68
  attributes: ImageScalerAttributes,
69
): ProgramInfoLoader => {
70
  const metadata = { ...imageScalerProgramMetadata, cacheHint: attributes.cacheKey };
71
  return { ...metadata, get: () => createImageScalerProgramInfo(handler, metadata, inputs, attributes) };
72
};
73

74
const createGetBiasMethod = (numChannels: number): string => {
75
  const codeLines: string[] = [`float getBias(float bias[${numChannels}], int channel) {`];
76
  for (let i = 0; i < numChannels; ++i) {
77
    if (i === 0) {
78
      codeLines.push('\t' + `if (channel == ${i}) { return bias[${i}]; }`);
79
    } else if (i === numChannels - 1) {
80
      codeLines.push('\t' + `else { return bias[${i}]; }`);
81
    } else {
82
      codeLines.push('\t' + `else if (channel == ${i}) { return bias[${i}]; }`);
83
    }
84
  }
85
  codeLines.push('\t' + '}');
86
  return codeLines.join('\n');
87
};
88

89
const validateInputs = (inputs: Tensor[]): void => {
90
  if (!inputs || inputs.length !== 1) {
91
    throw new Error('ImageScaler requires 1 input.');
92
  }
93
  if (inputs[0].dims.length !== 4) {
94
    throw new Error('Invalid input shape.');
95
  }
96
  if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
97
    throw new Error('Invalid input type.');
98
  }
99
};
100

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.