onnxruntime

Форк
0
154 строки · 5.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { NUMBER_TYPES, OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
11

12
/**
 * Parameters of a Slice program. `starts[i]` and `ends[i]` describe the half-open range
 * taken along axis `axes[i]` of the input tensor.
 */
export interface SliceAttributes extends AttributeWithCacheKey {
  /** First index (inclusive) to keep on each sliced axis; negative values count from the back. */
  readonly starts: number[];
  /** Index (exclusive) at which to stop on each sliced axis; negative values count from the back. */
  readonly ends: number[];
  /** Axes that `starts`/`ends` apply to; an empty array means no explicit axes were given. */
  readonly axes: number[];
}
17

18
/**
 * Metadata shared by every Slice WebGL program: one unpacked input texture, sampled as `A`.
 */
const sliceProgramMetadata: { name: string; inputNames: string[]; inputTypes: TextureType[] } = {
  name: 'Slice',
  inputNames: ['A'],
  inputTypes: [TextureType.unpacked],
};
23

24
/**
 * Slice operator whose starts/ends/axes come from node attributes.
 *
 * Validates that exactly one numeric input is present, then runs the Slice program built by
 * `createSliceProgramInfo`, letting the inference handler cache it under `attributes.cacheKey`.
 *
 * @returns a single-element array holding the sliced tensor.
 */
export const slice: OperatorImplementation<SliceAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: SliceAttributes,
): Tensor[] => {
  validateInputs(inputs);

  // The program body is only generated when the handler has no cached program for this key.
  const programLoader = {
    ...sliceProgramMetadata,
    cacheHint: attributes.cacheKey,
    get: () => createSliceProgramInfo(inferenceHandler, inputs[0], attributes),
  };

  return [inferenceHandler.run(programLoader, inputs)];
};
40

41
/**
 * Reads the `starts`, `ends` and optional `axes` integer attributes of a Slice node.
 * A missing `axes` attribute becomes an empty array.
 */
export const parseSliceAttributes: OperatorInitialization<SliceAttributes> = (node: Graph.Node): SliceAttributes => {
  const { attributes } = node;
  return createAttributeWithCacheKey({
    starts: attributes.getInts('starts'),
    ends: attributes.getInts('ends'),
    axes: attributes.getInts('axes', []),
  });
};
47

48
/**
 * Maps a Slice start/end index onto the range [0, dim], following ONNX Slice semantics for
 * positive steps. Negative indices count from the back of the axis, and anything still out of
 * range after that adjustment is clamped.
 */
const clampSliceIndex = (index: number, dim: number): number => {
  const absolute = index < 0 ? index + dim : index;
  return Math.min(Math.max(absolute, 0), dim);
};

/**
 * Builds the WebGL program that copies the sliced region of `input` (unit steps only).
 *
 * Each output coordinate on a sliced axis is shifted by that axis's clamped start before the
 * input texture `A` is sampled, so the start offsets are baked into the generated shader.
 *
 * @param _inferenceHandler unused by this builder.
 * @param input the tensor being sliced.
 * @param attributes starts/ends/axes. When `axes` is empty, it defaults to
 *   `[0, 1, ..., starts.length - 1]`, as in the ONNX Slice specification.
 * @returns program info whose output shape has `end - start` (never negative) on each sliced axis.
 * @throws Error if starts, ends and axes differ in length, or if an axis is outside the input rank.
 */
const createSliceProgramInfo = (
  _inferenceHandler: WebGLInferenceHandler,
  input: Tensor,
  attributes: SliceAttributes,
): ProgramInfo => {
  const inputRank = input.dims.length;
  const rawStarts = attributes.starts;
  const rawEnds = attributes.ends;
  if (rawStarts.length !== rawEnds.length) {
    throw new Error('Slice starts and ends must have the same length.');
  }

  // Default to one axis per `starts` entry (not one per input dimension). Otherwise a `starts`
  // list shorter than the rank would be indexed past its end.
  const axes = attributes.axes.length === 0 ? rawStarts.map((_val, i) => i) : attributes.axes;
  if (axes.length !== rawStarts.length) {
    throw new Error('Slice axes must have the same length as starts and ends.');
  }
  const normalizedAxes = ShapeUtil.normalizeAxes(axes, inputRank);
  if (normalizedAxes.some((axis) => axis < 0 || axis >= inputRank)) {
    throw new Error('Slice axis is out of range.');
  }

  const starts = rawStarts.map((start, i) => clampSliceIndex(start, input.dims[normalizedAxes[i]]));
  const ends = rawEnds.map((end, i) => clampSliceIndex(end, input.dims[normalizedAxes[i]]));

  const outputShape = input.dims.slice();
  const sliceOps: string[] = [];
  for (let i = 0; i < normalizedAxes.length; i++) {
    const axis = normalizedAxes[i];
    // With a positive step, end <= start selects nothing: use an empty dimension, never a negative one.
    outputShape[axis] = Math.max(ends[i] - starts[i], 0);
    if (starts[i] > 0) {
      // Shift the output coordinate back into input space; a zero start needs no offset.
      sliceOps.push(`outputIdx[${axis}] += ${starts[i]};`);
    }
  }

  const rank = outputShape.length;
  const shaderSource = `
      float process(int outputIdx[${rank}]) {
        ${sliceOps.join('\n      ')}
        return _A(outputIdx);
      }`;
  return {
    ...sliceProgramMetadata,
    output: { dims: outputShape, type: input.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
90

91
/**
 * Checks the inputs of the attribute-based Slice: exactly one tensor, of a numeric type.
 *
 * @throws Error if the input count or the element type is not acceptable.
 */
const validateInputs = (inputs: Tensor[]): void => {
  const hasSingleInput = !!inputs && inputs.length === 1;
  if (!hasSingleInput) {
    throw new Error('Slice requires 1 input.');
  }

  const isNumeric = NUMBER_TYPES.indexOf(inputs[0].type) >= 0;
  if (!isNumeric) {
    throw new Error('Invalid input type.');
  }
};
99

100
/**
 * Slice operator whose starts/ends/axes/steps arrive as tensor inputs 1–4.
 *
 * Those parameter tensors are converted into `SliceAttributes` up front (they must be
 * initializers). Only the data tensor (input 0) is handed to the WebGL program.
 *
 * @returns a single-element array holding the sliced tensor.
 */
export const sliceV10 = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
  validateInputsV10(inputs);
  const attributes = generateSliceAttributesFromInputs(inferenceHandler, inputs);
  const [data] = inputs;

  const sliced = inferenceHandler.run(
    {
      ...sliceProgramMetadata,
      cacheHint: attributes.cacheKey,
      get: () => createSliceProgramInfo(inferenceHandler, data, attributes),
    },
    [data],
  );
  return [sliced];
};
113

114
/**
 * Builds `SliceAttributes` from the parameter inputs of the tensor-input Slice variant.
 *
 * Inputs: 1 = starts, 2 = ends, 3 = axes (optional), 4 = steps (optional).
 * Every parameter input present must be a session initializer, because the values are needed
 * when the shader is generated. Only steps equal to 1 are supported.
 *
 * @returns the attributes, with a cache key of the form `axes;starts;ends`.
 * @throws Error for a non-initializer parameter input or any step other than 1.
 */
const generateSliceAttributesFromInputs = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
): SliceAttributes => {
  // starts and ends are mandatory; axes and steps are checked only when supplied.
  const parameterIndices = [1, 2];
  if (inputs.length >= 4) {
    parameterIndices.push(3);
  }
  if (inputs.length >= 5) {
    parameterIndices.push(4);
  }
  const hasDynamicParameter = parameterIndices.some(
    (index) => !inferenceHandler.session.isInitializer(inputs[index].dataId),
  );
  if (hasDynamicParameter) {
    throw new Error('dynamic slice attributes are not allowed');
  }

  const hasNonUnitStep = inputs.length >= 5 && inputs[4].integerData.some((step: number) => step !== 1);
  if (hasNonUnitStep) {
    throw new Error('currently non-1 steps is not supported for Slice');
  }

  const starts = Array.from(inputs[1].integerData);
  const ends = Array.from(inputs[2].integerData);
  let axes: number[] = [];
  if (inputs.length >= 4) {
    axes = Array.from(inputs[3].integerData);
  }
  return { starts, ends, axes, cacheKey: `${axes};${starts};${ends}` };
};
137

138
/**
 * Checks the inputs of the tensor-input Slice variant.
 *
 * Requires 3 to 5 inputs. Every input after the data tensor (starts, ends, and optionally axes
 * and steps) must be a 1-D `int32` tensor.
 *
 * @throws Error if the input count or any parameter tensor's type/rank is not acceptable.
 */
const validateInputsV10 = (inputs: Tensor[]): void => {
  const count = inputs ? inputs.length : 0;
  if (!inputs || count < 3 || count > 5) {
    throw new Error('Invalid input number.');
  }

  for (let index = 1; index < count; index++) {
    const parameter = inputs[index];
    const isInt32Vector = parameter.type === 'int32' && parameter.dims.length === 1;
    if (!isInt32Vector) {
      throw new Error('Invalid input type.');
    }
  }
};
155

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.