onnxruntime

Форк
0
58 строк · 2.3 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Tensor } from '../../../tensor';
5
import { WebGLInferenceHandler } from '../inference-handler';
6

7
import { calculateOutputShape, ConvAttributes } from './conv';
8
import { createPackedIm2ColProgramInfoLoader } from './im2col-pack';
9
import { createPackedMatmulProgramInfoLoader } from './matmul-pack';
10

11
export const conv2DPackedPointwise = (
12
  inferenceHandler: WebGLInferenceHandler,
13
  inputs: readonly Tensor[],
14
  attributes: ConvAttributes,
15
): Tensor => {
16
  const xshape = inputs[0].dims;
17
  const kshape = inputs[1].dims;
18
  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
19
  const reshapedX = inferenceHandler.reshapePacked(inputs[0], [xshape[1], xshape[2] * xshape[3]]);
20
  const reshapedK = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1]]);
21

22
  const matmulInputs = inputs.length > 2 ? [reshapedK, reshapedX, inputs[2]] : [reshapedK, reshapedX];
23
  const matmulOutput = inferenceHandler.run(
24
    createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes),
25
    matmulInputs,
26
  );
27
  return inferenceHandler.reshapePacked(matmulOutput, outputShape);
28
};
29

30
export const conv2DPacked = (
31
  inferenceHandler: WebGLInferenceHandler,
32
  inputs: readonly Tensor[],
33
  attributes: ConvAttributes,
34
): Tensor => {
35
  const xshape = inputs[0].dims;
36
  const kshape = inputs[1].dims;
37
  const outputShape = calculateOutputShape(xshape, kshape, attributes.dilations, attributes.pads, attributes.strides);
38

39
  // run im2col
40
  const im2colOutput = inferenceHandler.run(
41
    createPackedIm2ColProgramInfoLoader(inferenceHandler, inputs[0], inputs[1], outputShape, attributes),
42
    [inputs[0]],
43
  );
44

45
  // reshape kernel
46
  const kernelReshaped = inferenceHandler.reshapePacked(inputs[1], [kshape[0], kshape[1] * kshape[2] * kshape[3]]);
47

48
  // run matmul
49
  const matmulInputs = inputs.length === 3 ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput];
50
  const matmulOutput = inferenceHandler.run(
51
    createPackedMatmulProgramInfoLoader(inferenceHandler, matmulInputs, attributes),
52
    matmulInputs,
53
  );
54

55
  // reshape output
56
  const outputReshaped = inferenceHandler.reshapePacked(matmulOutput, outputShape);
57
  return outputReshaped;
58
};
59

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.