1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { Graph } from '../../../graph';
5
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
6
import { Tensor } from '../../../tensor';
7
import { WebGLInferenceHandler } from '../inference-handler';
9
import { transpose, TransposeAttributes } from './transpose';
11
export interface DepthToSpaceAttributes {
16
export const depthToSpace: OperatorImplementation<DepthToSpaceAttributes> = (
17
inferenceHandler: WebGLInferenceHandler,
19
attributes: DepthToSpaceAttributes,
21
validateInputs(inputs);
22
const blocksize = attributes.blocksize;
23
const blocksizeSqr = blocksize * blocksize;
24
const transposePerm = attributes.mode === 'DCR' ? [0, 3, 4, 1, 5, 2] : [0, 1, 4, 2, 5, 3];
25
const firstReshapeShape =
26
attributes.mode === 'DCR'
31
inputs[0].dims[1] / blocksizeSqr,
37
inputs[0].dims[1] / blocksizeSqr,
44
// const transpose = new WebGLTranspose();
45
// const attributes = new Attribute(undefined);
46
// attributes.set('perm', 'ints', transposePerm);
47
// transpose.initialize(attributes);
50
const firstReshapedTensor = inferenceHandler.reshapeUnpacked(inputs[0], firstReshapeShape);
53
const transposeAttributes: TransposeAttributes = { perm: transposePerm, cacheKey: `${transposePerm}` };
54
const [transposeOutput] = transpose(inferenceHandler, [firstReshapedTensor], transposeAttributes);
57
const secondReshapeShape = [
59
inputs[0].dims[1] / blocksizeSqr,
60
inputs[0].dims[2] * blocksize,
61
inputs[0].dims[3] * blocksize,
63
const result = inferenceHandler.reshapeUnpacked(transposeOutput, secondReshapeShape);
67
export const parseDepthToSpaceAttributes: OperatorInitialization<DepthToSpaceAttributes> = (
69
): DepthToSpaceAttributes => {
70
// processing node attributes
71
const blocksize = node.attributes.getInt('blocksize');
73
throw new Error(`blocksize must be >= 1, but got : ${blocksize} for DepthToSpace`);
75
const mode = node.attributes.getString('mode', 'DCR');
76
if (mode !== 'DCR' && mode !== 'CRD') {
77
throw new Error(`unrecognized mode: ${mode} for DepthToSpace`);
79
return { mode, blocksize };
82
const validateInputs = (inputs: Tensor[]): void => {
83
if (inputs.length !== 1) {
84
throw new Error(`DepthToSpace expect 1 inputs, but got ${inputs.length}`);
87
// Input has to be a 4-D tensor
88
// TODO: Support string depth-to-space.
89
if (inputs[0].type === 'string' || inputs[0].dims.length !== 4) {
90
throw new TypeError('DepthToSpace input should be a 4-D numeric tensor');