1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { ShapeUtil } from '../../../util';
9
import { WebGLInferenceHandler } from '../inference-handler';
10
import { ProgramInfo, TextureType } from '../types';
12
/**
 * Attributes for the WebGL Transpose operator, carrying a cache key
 * (via AttributeWithCacheKey) so compiled programs can be looked up by it.
 */
export interface TransposeAttributes extends AttributeWithCacheKey {
13
// Axis permutation read from the node's `perm` attribute (default `[]`); a value whose
// length differs from the input rank is replaced by reversed axis order in getAdjustedPerm.
readonly perm: number[];
16
// Program metadata shared by the run() request in `transpose` and by the ProgramInfo
// returned from createTransposeProgramInfo (both spread it). The single entry in
// `inputTypes` declares that the one input is consumed as an unpacked texture.
const transposeProgramMetadata = {
19
inputTypes: [TextureType.unpacked],
22
/**
 * WebGL implementation of the Transpose operator.
 *
 * Validates `inputs` with validateInputs (exactly one float32/float64 tensor), then asks
 * the inference handler to run the transpose program for `inputs[0]`. The program is
 * built lazily by the `get` callback and tagged with `attributes.cacheKey` via
 * `cacheHint`. Presumably this lets identical attributes reuse a compiled program;
 * confirm against WebGLInferenceHandler.run.
 *
 * @param inferenceHandler - WebGL handler used to run the program; also forwarded to the builder.
 * @param attributes - parsed Transpose attributes; `perm` is forwarded to createTransposeProgramInfo.
 */
export const transpose: OperatorImplementation<TransposeAttributes> = (
23
inferenceHandler: WebGLInferenceHandler,
25
attributes: TransposeAttributes,
27
validateInputs(inputs);
28
const output = inferenceHandler.run(
30
...transposeProgramMetadata,
31
// Cache key derived from the attributes (set by createAttributeWithCacheKey).
cacheHint: attributes.cacheKey,
32
// Deferred builder: only invoked when a ProgramInfo is actually needed.
get: () => createTransposeProgramInfo(inferenceHandler, inputs[0], attributes.perm),
39
/**
 * Builds TransposeAttributes from a graph node by reading its integer-list `perm`
 * attribute, with `[]` passed as the default value. The object is wrapped with
 * createAttributeWithCacheKey so it carries the `cacheKey` that `transpose` uses as
 * its cache hint. For a non-scalar input, an empty `perm` is later replaced by
 * reversed axis order in getAdjustedPerm.
 */
export const parseTransposeAttributes: OperatorInitialization<TransposeAttributes> = (
41
): TransposeAttributes => createAttributeWithCacheKey({ perm: node.attributes.getInts('perm', []) });
43
/**
 * Builds the ProgramInfo for the transpose shader.
 *
 * Normalizes `perm` with getAdjustedPerm and derives the output shape from it with
 * getOutputShape. It then emits GLSL made of two parts: the generated `perm`
 * index-mapping function (getPermFunctionBody) and a `float process(int indices[rank])`
 * entry point. The output descriptor is an unpacked texture with the permuted dims and
 * the input's element type.
 *
 * `_inferenceHandler` is underscore-prefixed, conventionally marking it as unused;
 * presumably kept to match other program-info builders' signatures — confirm.
 */
const createTransposeProgramInfo = (
44
_inferenceHandler: WebGLInferenceHandler,
48
const inputShape = input.dims;
49
// Normalize once; the adjusted perm feeds both the output shape and the shader.
perm = getAdjustedPerm(inputShape, perm);
50
const unpackedOutputShape = getOutputShape(inputShape, perm);
51
const rank = inputShape.length;
52
// NOTE(review): the three comments below look like leftover debug text. `inputs` is not
// a name visible in this function (the tensor parameter used here is `input`).
// A dims=[${inputs[0].dims.toString()}]
53
// out Dims=[${unpackedOutputShape.toString()}]
54
// based on perm=[${perm.toString()}]
55
const shaderSource = `
56
${getPermFunctionBody('perm', perm, rank)}
57
float process(int indices[${rank}]) {
63
...transposeProgramMetadata,
64
output: { dims: unpackedOutputShape, type: input.type, textureType: TextureType.unpacked },
69
/**
 * Returns the axis permutation to use for an input of shape `inputShape`.
 *
 * If `perm` is provided but its length differs from the input rank, it is replaced by
 * the reversed axis order `[rank-1, ..., 1, 0]`. This includes the default `[]` for
 * non-scalar inputs, and presumably matches the ONNX Transpose default (confirm
 * against the spec). Otherwise `perm` is presumably returned unchanged.
 *
 * NOTE(review): only the length is checked here. A same-length array that is not a
 * permutation of 0..rank-1 (duplicates, out-of-range values) is not rejected. A
 * null/undefined `perm` also skips the replacement because of the `perm &&` guard.
 */
const getAdjustedPerm = (inputShape: readonly number[], perm: number[]): number[] => {
70
if (perm && perm.length !== inputShape.length) {
71
// keys() yields 0..rank-1; reversing it gives the default reversed axis order.
perm = [...inputShape.keys()].reverse();
76
/**
 * Computes the transposed output shape of `inputShape` under `perm`.
 *
 * `perm` is normalized with getAdjustedPerm first, so a raw attribute value may be
 * passed. Re-normalizing an already-adjusted perm, as createTransposeProgramInfo does,
 * should leave it unchanged, since its length already matches the rank. The shape comes
 * from ShapeUtil.sortBasedOnPerm, presumably `out[i] = inputShape[perm[i]]`; confirm
 * in util.
 */
const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] => {
77
perm = getAdjustedPerm(inputShape, perm);
78
return ShapeUtil.sortBasedOnPerm(inputShape, perm);
81
/**
 * Generates the GLSL function `void <name>(out int a[rank], int src[rank])`. It scatters
 * `src` through `perm`: for each axis i it emits the statement `a[perm[i]]=src[i];`.
 *
 * Suppose `src` is an output coordinate and output axis i is input axis perm[i] (the
 * usual transpose convention; confirm against ShapeUtil.sortBasedOnPerm). Then `a`
 * receives the matching input coordinate.
 *
 * @param name - GLSL function name to emit.
 * @param perm - axis permutation; presumably already normalized (length === rank).
 * @param rank - tensor rank; also the length of both GLSL int arrays.
 * @returns the function source, lines joined with '\n'.
 */
const getPermFunctionBody = (name: string, perm: number[], rank: number): string => {
82
// Collects GLSL source lines. Despite the name, nothing is reversed here.
const reverseFunc = [];
83
reverseFunc.push(`void ${name}(out int a[${rank}], int src[${rank}]) {`);
84
// One assignment per axis: a[perm[i]] = src[i].
for (let i = 0; i < rank; ++i) {
85
reverseFunc.push(`\ta[${perm[i]}]=src[${i}];`);
87
reverseFunc.push('\t}');
88
return reverseFunc.join('\n');
91
const validateInputs = (inputs: Tensor[]): void => {
92
if (!inputs || inputs.length !== 1) {
93
throw new Error('Transpose requires 1 input.');
96
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
97
throw new Error('input should be float tensor');