1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
4
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
5
import { Graph } from '../../../graph';
6
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
7
import { Tensor } from '../../../tensor';
8
import { WebGLInferenceHandler } from '../inference-handler';
9
import { ProgramInfo, ProgramInfoLoader, TextureType } from '../types';
11
export interface LrnAttributes extends AttributeWithCacheKey {
18
export const lrn: OperatorImplementation<LrnAttributes> = (
19
inferenceHandler: WebGLInferenceHandler,
21
attributes: LrnAttributes,
23
validateInputs(inputs);
25
// if (inferenceHandler.session.pack) {
26
// return [inferenceHandler.run(createPackedLrnProgramInfoLoader(inferenceHandler, inputs, attributes),
29
return [inferenceHandler.run(createLrnProgramInfoLoader(inputs, attributes), inputs)];
33
export const parseLrnAttributes: OperatorInitialization<LrnAttributes> = (node: Graph.Node): LrnAttributes => {
34
const alpha = node.attributes.getFloat('alpha', 0.0001);
35
const beta = node.attributes.getFloat('beta', 0.75);
36
const bias = node.attributes.getFloat('bias', 1.0);
37
const size = node.attributes.getInt('size');
39
return createAttributeWithCacheKey({ alpha, beta, bias, size });
42
const lrnProgramMetadata = {
45
inputTypes: [TextureType.unpacked],
48
function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo {
49
const C = inputs[0].dims[1];
50
const rank = inputs[0].dims.length;
51
const from = -Math.floor((attributes.size - 1) / 2);
52
const to = Math.ceil((attributes.size - 1) / 2);
53
const alpha = `float(${attributes.alpha}) / float(${attributes.size})`;
54
const bias = `float(${attributes.bias})`;
55
const beta = `float(${attributes.beta})`;
57
const shaderSource = `
58
float process(int indices[${rank}]) {
60
float x = _X(indices);
61
float square_sum = 0.0;
63
for (int i = ${from}; i <= ${to}; i++) {
65
if (c >= 0 && c < ${C}) {
67
float j = _X(indices);
71
return x / pow(${bias} + ${alpha} * square_sum, ${beta});
74
...lrnProgramMetadata,
75
cacheHint: attributes.cacheKey,
76
output: { dims: inputs[0].dims, type: inputs[0].type, textureType: TextureType.unpacked },
81
export function createLrnProgramInfoLoader(inputs: Tensor[], attributes: LrnAttributes): ProgramInfoLoader {
82
return { ...lrnProgramMetadata, cacheHint: attributes.cacheKey, get: () => createLrnProgramInfo(inputs, attributes) };
85
const validateInputs = (inputs: Tensor[]): void => {
86
if (!inputs || inputs.length !== 1) {
87
throw new Error('LRN requires 1 input.');
89
if (inputs[0].dims.length !== 4) {
90
throw new Error('currently only support LRN for input with "NCHW" format');
92
if (inputs[0].type !== 'float32') {
93
throw new Error('input should be float type');