1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { NUMBER_TYPES } from '../../../operators';
5
import { Tensor } from '../../../tensor';
6
import { WebGLInferenceHandler } from '../inference-handler';
7
import { ProgramInfo, ProgramMetadata, TextureType } from '../types';
9
export const tile = (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] => {
10
validateInputs(inputs);
12
const tileProgramMetadata = {
15
inputTypes: [TextureType.unpacked],
18
const output = inferenceHandler.run(
19
{ ...tileProgramMetadata, get: () => createTileProgramInfo(inferenceHandler, inputs, tileProgramMetadata) },
25
const createTileProgramInfo = (
26
_handler: WebGLInferenceHandler,
28
tileProgramMetadata: ProgramMetadata,
30
const inputShape = inputs[0].dims.slice();
31
const outputShape = new Array(inputShape.length);
33
const tileOps: string[] = [];
34
for (let i = 0; i < inputShape.length; i++) {
35
outputShape[i] = inputShape[i] * inputs[1].numberData[i];
36
tileOps.push(`inputIdx[${i}] = int(mod(float(outputIdx[${i}]), ${inputShape[i]}.));`);
39
const rank = outputShape.length;
40
const shaderSource = `
41
float process(int outputIdx[${rank}]) {
42
int inputIdx[${rank}];
48
...tileProgramMetadata,
49
output: { dims: outputShape, type: inputs[0].type, textureType: TextureType.unpacked },
54
const validateInputs = (inputs: Tensor[]): void => {
55
if (!inputs || inputs.length !== 2) {
56
throw new Error('Tile requires 2 input.');
58
if (inputs[1].dims.length !== 1) {
59
throw new Error('The second input shape must 1 dimension.');
61
if (inputs[1].dims[0] !== inputs[0].dims.length) {
62
throw new Error('Invalid input shape.');
64
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
65
throw new Error('Invalid input type.');
67
if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
68
throw new Error('Invalid repeat type.');