onnxruntime

Форк
0
70 строк · 2.1 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { NUMBER_TYPES } from '../../../operators';
5
import { Tensor } from '../../../tensor';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
8

9
export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
10
  validateInputs(inputs);
11

12
  const tileProgramMetadata = {
13
    name: 'Tile',
14
    inputNames: ['A'],
15
    inputTypes: [TextureType.unpacked],
16
  };
17

18
  const output = inferenceHandler.run(
19
    { ...tileProgramMetadata, get: () => createTileProgramInfo(inferenceHandler, inputs, tileProgramMetadata) },
20
    inputs,
21
  );
22
  return [output];
23
};
24

25
const createTileProgramInfo = (
26
  _handler: WebGLInferenceHandler,
27
  inputs: Tensor[],
28
  tileProgramMetadata: ProgramMetadata,
29
): ProgramInfo => {
30
  const inputShape = inputs[0].dims.slice();
31
  const outputShape = new Array(inputShape.length);
32

33
  const tileOps: string[] = [];
34
  for (let i = 0; i < inputShape.length; i++) {
35
    outputShape[i] = inputShape[i] * inputs[1].numberData[i];
36
    tileOps.push(`inputIdx[${i}] = int(mod(float(outputIdx[${i}]), ${inputShape[i]}.));`);
37
  }
38

39
  const rank = outputShape.length;
40
  const shaderSource = `
41
      float process(int outputIdx[${rank}]) {
42
        int inputIdx[${rank}];
43
        ${tileOps.join('\n')}
44
        return _A(inputIdx);
45
      }
46
    `;
47
  return {
48
    ...tileProgramMetadata,
49
    output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
50
    shaderSource,
51
  };
52
};
53

54
const validateInputs = (inputs: Tensor[]): void => {
55
  if (!inputs || inputs.length !== 2) {
56
    throw new Error('Tile requires 2 input.');
57
  }
58
  if (inputs[1].dims.length !== 1) {
59
    throw new Error('The second input shape must 1 dimension.');
60
  }
61
  if (inputs[1].dims[0] !== inputs[0].dims.length) {
62
    throw new Error('Invalid input shape.');
63
  }
64
  if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
65
    throw new Error('Invalid input type.');
66
  }
67
  if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
68
    throw new Error('Invalid repeat type.');
69
  }
70
};
71

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.