onnxruntime

Форк
0
228 строк · 8.4 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { DataType } from '../../../wasm-common';
5
import { TensorView } from '../../tensor-view';
6
import { ShapeUtil } from '../../util';
7
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8
import { ComputeContext, ProgramInfo, ProgramUniform, TensorInfo } from '../types';
9

10
import {
11
  createTensorShapeVariables,
12
  getElementAt,
13
  IndicesHelper,
14
  inputVariable,
15
  outputVariable,
16
  ShaderHelper,
17
  UniformsArrayType,
18
} from './common';
19

20
/**
 * Attributes of a Slice operation.
 *
 * When `axes` is non-empty, `starts` and `ends` are expected to have the same
 * length as `axes`; otherwise `starts` and `ends` must match each other in length.
 */
export interface SliceAttributes extends AttributeWithCacheKey {
  /** Axes the slice applies to; may be empty. */
  readonly axes: number[];
  /** Start value for each sliced axis (negative values are offset by the axis size later on). */
  readonly starts: number[];
  /** End value for each sliced axis (negative values are offset by the axis size later on). */
  readonly ends: number[];
}
25

26
/**
 * Validates the operator inputs and attributes before a Slice program is built.
 *
 * @param inputs - Operator inputs; at least one (the data tensor) is required. Every
 *   additional input must have an int32 or int64 data type.
 * @param attributes - Slice attributes whose `axes`, `starts` and `ends` lengths are checked
 *   for consistency.
 * @throws Error when there are no inputs, when the attribute lengths disagree, or when
 *   any input after the first is neither int32 nor int64.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: SliceAttributes): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('too few inputs');
  }

  const { axes, starts, ends } = attributes;
  if (axes.length !== 0) {
    if (axes.length !== starts.length || axes.length !== ends.length) {
      throw new Error('axes, starts and ends must have the same length');
    }
  } else if (starts.length !== ends.length) {
    throw new Error('starts and ends must have the same length');
  }

  // Inputs 1..n must be integer tensors. Report the actual input index so the message
  // matches the one produced by readInput for the same input.
  for (let i = 1; i < inputs.length; ++i) {
    const { dataType } = inputs[i];
    if (dataType !== DataType.int32 && dataType !== DataType.int64) {
      throw new Error(`Input ${i} must be an array of int32 or int64`);
    }
  }
};
43

44
/**
 * Reads the integer contents of the input at position `idx` as plain numbers.
 *
 * @param inputs - Operator inputs.
 * @param idx - Position of the input to read.
 * @returns The values converted with `Number(...)`, or an empty array when `inputs`
 *   has no entry at `idx`.
 * @throws Error when the input exists but is neither int64 nor int32.
 */
const readInput = (inputs: readonly TensorView[], idx: number): number[] => {
  // A missing (optional) input yields an empty list.
  if (!(inputs.length > idx)) {
    return [];
  }

  const tensor = inputs[idx];
  switch (tensor.dataType) {
    case DataType.int64:
      return Array.from(tensor.getBigInt64Array(), (v) => Number(v));
    case DataType.int32:
      return Array.from(tensor.getInt32Array(), (v) => Number(v));
    default:
      throw new Error(`Input ${idx} must be an array of int32 or int64`);
  }
};
57

58
/**
 * Resolves the effective slice attributes.
 *
 * When more than one input is present, `starts`, `ends` and `axes` are read from
 * inputs 1, 2 and 3 respectively; an empty `axes` result is replaced by every axis
 * of input 0 (`0 .. rank-1`). Otherwise the supplied `attributes` are returned as-is.
 *
 * @param inputs - Operator inputs.
 * @param attributes - Attributes to fall back to when only the data input is given.
 * @returns The attributes to use for building the Slice program.
 */
const createSliceAttributesFromInputs = (
  inputs: readonly TensorView[],
  attributes: SliceAttributes,
): SliceAttributes => {
  if (inputs.length <= 1) {
    return attributes;
  }

  const starts = readInput(inputs, 1);
  const ends = readInput(inputs, 2);
  const axesFromInput = readInput(inputs, 3);
  const axes =
    axesFromInput.length === 0 ? Array.from({ length: inputs[0].dims.length }, (_, i) => i) : axesFromInput;

  return createAttributeWithCacheKey({ starts, ends, axes });
};
74

75
/**
 * Normalizes a single start or end value against the size of the axis it refers to.
 *
 * A negative `value` is first offset by the axis size. The result is then clamped to
 * `[0, size - 1]` when the matching step is negative and to `[0, size]` otherwise.
 *
 * @param value - Raw start or end value.
 * @param index - Position of the value within `axes` / `steps`.
 * @param inputShape - Shape of the tensor being sliced.
 * @param axes - Axis for each position.
 * @param steps - Step for each position.
 * @returns The clamped value.
 */
const fixStartEndValues = (
  value: number,
  index: number,
  inputShape: readonly number[],
  axes: readonly number[],
  steps: readonly number[],
): number => {
  const axisSize = inputShape[axes[index]];
  const wrapped = value < 0 ? value + axisSize : value;
  const upperBound = steps[index] < 0 ? axisSize - 1 : axisSize;
  return Math.max(0, Math.min(wrapped, upperBound));
};
92

93
/**
 * Generates the WGSL helper `calculateInputIndices`, which maps the indices of an
 * output element to the indices of the input element it is copied from.
 *
 * For each dimension the generated code computes `output_index * step + start`
 * (plus a carry), reduces it modulo the input dimension, and applies a flip when the
 * sign uniform for that dimension is negative.
 *
 * The loop walks the dimensions from the last one (`rank - 1`) down to 0, so every
 * uniform and index access stays within `0 .. rank - 1`.
 *
 * @param input - Indices helper of the input variable.
 * @param output - Indices helper of the output variable.
 * @param inputShape - Shape of the input tensor; only its length (rank) is used here.
 * @returns WGSL source of the function.
 */
const calculateInputIndicesImpl = (
  input: IndicesHelper,
  output: IndicesHelper,
  inputShape: readonly number[],
): string =>
  `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
          var input_indices: ${input.type.indices};
          var carry = 0u;
          for (var i = ${inputShape.length - 1}; i >= 0; i--) {
            let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
            let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)};
            let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)};
            let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)};
            var output_index = ${output.indicesGet('output_indices', 'i')};
            var input_index = output_index * steps_i + starts_i + carry;
            carry = input_index / input_shape_i;
            input_index = input_index % input_shape_i;
            if (signs_i < 0) {
              input_index = input_shape_i - input_index - 1u + starts_i;
            }
            ${input.indicesSet('input_indices', 'i', 'input_index')};
          }
          return input_indices;
      }`;
117

118
/**
 * Builds the program description (shader source, uniforms, output shape, dispatch size)
 * for a Slice operation.
 *
 * Steps are read from input 4 and default to 1 for every sliced axis when that input is
 * absent or empty. `starts`, `ends` and `steps` are expanded to one entry per input
 * dimension, placed at the position of the axis they belong to, so `axes` may be given in
 * any order. Negative steps are converted to positive ones (with start and end swapped) and
 * their direction is passed to the shader through the `signs` uniform.
 *
 * @param inputs - Operator inputs; input 0 is the tensor to slice.
 * @param attributes - Resolved slice attributes (`starts`, `ends`, `axes`).
 * @returns The program info for the Slice shader.
 * @throws Error when a step is 0, or when `starts`, `ends`, `steps` and `axes` do not have
 *   the same number of elements.
 */
const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const rank = inputShape.length;
  const inputSize = ShapeUtil.size(inputShape);
  const axes =
    attributes.axes.length > 0 ? ShapeUtil.normalizeAxes(attributes.axes, rank) : [...Array(rank).keys()];

  // Steps as supplied, i.e. one entry per element of `axes` (same order as `axes`).
  let axisSteps = readInput(inputs, 4);
  if (axisSteps.some((step) => step === 0)) {
    throw new Error('step cannot be 0');
  }
  if (axisSteps.length === 0) {
    axisSteps = Array(axes.length).fill(1);
  }

  if (axes.length !== attributes.starts.length || axes.length !== attributes.ends.length) {
    throw new Error('start, ends and axes should have the same number of elements');
  }
  if (axisSteps.length !== axes.length) {
    throw new Error('steps and axes should have the same number of elements');
  }

  // Per-dimension values. Dimensions that are not sliced keep the full range [0, dim) with step 1.
  const starts: number[] = Array(rank).fill(0);
  const ends: number[] = inputShape.slice(0);
  const steps: number[] = Array(rank).fill(1);
  axes.forEach((axis, i) => {
    starts[axis] = fixStartEndValues(attributes.starts[i], i, inputShape, axes, axisSteps);
    ends[axis] = fixStartEndValues(attributes.ends[i], i, inputShape, axes, axisSteps);
    steps[axis] = axisSteps[i];
  });

  const signs = steps.map((step) => Math.sign(step));
  // Convert negative steps to positive steps and reverse starts and ends.
  // NOTE(review): the shader's flip for negative signs appears to count from the last element
  // of the axis; confirm the result when the requested start is not the last element.
  steps.forEach((step, axis) => {
    if (step < 0) {
      const numSteps = (ends[axis] - starts[axis]) / step;
      const newEnd = starts[axis];
      const newStart = newEnd + numSteps * step;
      starts[axis] = newStart;
      ends[axis] = newEnd;
      steps[axis] = -step;
    }
  });

  // Output rank equals the input rank; only the sliced axes change size.
  const outputShape = inputShape.slice(0);
  axes.forEach((axis) => {
    outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
  });
  const outputTensorInfo: TensorInfo = { dims: outputShape, dataType: inputs[0].dataType };

  const output = outputVariable('output', inputs[0].dataType, outputShape.length);
  const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
  const outputSize = ShapeUtil.size(outputShape);
  const uniforms: UniformsArrayType = [
    { name: 'outputSize', type: 'u32' },
    { name: 'starts', type: 'u32', length: starts.length },
    { name: 'signs', type: 'i32', length: signs.length },
    { name: 'steps', type: 'u32', length: steps.length },
  ];

  // Must be listed in the same order as `uniforms` above, followed by the shape uniforms.
  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize },
    { type: DataType.uint32, data: starts },
    { type: DataType.int32, data: signs },
    { type: DataType.uint32, data: steps },
    ...createTensorShapeVariables(inputs[0].dims, outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => `
      ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
        ${calculateInputIndicesImpl(input, output, inputShape)}
        ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
          let output_indices = ${output.offsetToIndices('global_idx')};
          let input_indices = calculateInputIndices(output_indices);
          ${output.setByOffset('global_idx', input.getByIndices('input_indices'))}
      }`;
  return {
    name: 'Slice',
    shaderCache: { hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank'] },
    getShaderSource,
    getRunData: () => ({
      outputs: [outputTensorInfo],
      // Dispatch is sized from the input element count; the shader itself guards on outputSize.
      dispatchGroup: { x: Math.ceil(inputSize / 64 /* workgroup size */) },
      programUniforms,
    }),
  };
};
210

211
/**
 * Entry point of the Slice operator: validates the inputs, resolves the effective
 * attributes (from the extra inputs when present) and runs the Slice program.
 *
 * The program is run with `inputs: [0]`, i.e. only the first input is handed to it.
 *
 * @param context - Compute context providing the inputs and the `compute` call.
 * @param attributes - Slice attributes parsed from the node.
 */
export const slice = (context: ComputeContext, attributes: SliceAttributes): void => {
  const { inputs } = context;
  validateInputs(inputs, attributes);

  const effectiveAttributes = createSliceAttributesFromInputs(inputs, attributes);
  const programInfo = createSliceProgramInfo(inputs, effectiveAttributes);

  // TODO: an empty output (size 0) is not handled specially here.
  context.compute(programInfo, { inputs: [0] });
};
222

223
/**
 * Converts a raw attribute record into `SliceAttributes`.
 *
 * The `starts`, `ends` and `axes` entries are taken as number arrays without further
 * checks and wrapped via `createAttributeWithCacheKey`.
 *
 * @param attributes - Raw attribute record.
 * @returns The slice attributes.
 */
export const parseSliceAttributes = (attributes: Record<string, unknown>): SliceAttributes =>
  createAttributeWithCacheKey({
    starts: attributes.starts as number[],
    ends: attributes.ends as number[],
    axes: attributes.axes as number[],
  });
229

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.