onnxruntime · Форк 0 / op-resolve-rules.ts · 156 строк · 6,7 КБ
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { argMax, argMin, parseArgMinMaxAttributes } from './ops/argminmax';
5
import { attention } from './ops/attention';
6
import { batchNorm } from './ops/batch-norm';
7
import { biasAdd } from './ops/bias-add';
8
import { biasSplitGelu } from './ops/bias-split-gelu';
9
import * as binaryOps from './ops/binary-op';
10
import { concat, parseConcatAttributes } from './ops/concat';
11
import { conv, parseConvAttributes } from './ops/conv';
12
import { convTranspose, parseConvTransposeAttributes } from './ops/conv-transpose';
13
import { cumsum, parseCumSumAttributes } from './ops/cumsum';
14
import { depthToSpace, parseDepthToSpaceAttributes } from './ops/depth-to-space';
15
import { einsum, parseEinsumAttributes } from './ops/einsum';
16
import { expand } from './ops/expand';
17
import { fastGelu } from './ops/fast-gelu';
18
import { gather, parseGatherAttributes } from './ops/gather';
19
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
20
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
21
import { gemm, parseGemmAttributes } from './ops/gemm';
22
import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention';
23
import { instanceNorm } from './ops/instance-norm';
24
import { layerNorm } from './ops/layer-norm';
25
import { matMul } from './ops/matmul';
26
import { matMulNBits, parseMatMulNBitsAttributes } from './ops/matmulnbits';
27
import { multiHeadAttention, parseMultiHeadAttentionAttributes } from './ops/multihead-attention';
28
import { pad } from './ops/pad';
29
import * as pool from './ops/pool';
30
import { dequantizeLinear, parseDequantizeLinearAttributes } from './ops/quantize-linear';
31
import { range } from './ops/range';
32
import {
33
  reduceL1,
34
  reduceL2,
35
  reduceLogSum,
36
  reduceLogSumExp,
37
  reduceMax,
38
  reduceMean,
39
  reduceMin,
40
  reduceProd,
41
  reduceSum,
42
  reduceSumSquare,
43
} from './ops/reduce';
44
import { parseResizeAttributes, resize } from './ops/resize';
45
import { rotaryEmbedding } from './ops/rotary-embedding';
46
import { skipLayerNorm } from './ops/skip-layer-norm';
47
import { parseSliceAttributes, slice } from './ops/slice';
48
import { parseSoftmaxAttributes, softmax } from './ops/softmax';
49
import { parseSplitAttributes, split } from './ops/split';
50
import { tile } from './ops/tile';
51
import { parseTransposeAttributes, transpose } from './ops/transpose';
52
import * as unaryOps from './ops/unary-op';
53
import { where } from './ops/where';
54
import { ComputeContext } from './types';
55

56
/**
 * Executes an operator against the given compute context.
 *
 * `attribute` is optional: operators registered without an attribute parser
 * receive nothing here. NOTE(review): presumably, when a parser is registered,
 * this is the value that parser returned — confirm against the caller.
 */
export type RunFunction = (context: ComputeContext, attribute?: unknown) => void;

/**
 * Converts an operator's raw attribute value into the form its run function
 * consumes. Input and output are deliberately typed `unknown`; each operator
 * defines its own attribute shape.
 */
export type ParseAttributeFunction = (attributeRaw: unknown) => unknown;

/**
 * Registration tuple for one operator: either just the run function, or the
 * run function paired with an attribute parser.
 */
export type OperatorImplementation = [RunFunction] | [RunFunction, ParseAttributeFunction];
59

60
/**
 * Registry of WebGPU operator implementations, keyed by operator type name
 * (e.g. 'Abs', 'Conv', 'Softmax').
 *
 * Each value is an {@link OperatorImplementation}: the run function, optionally
 * followed by a parser for the operator's attributes. Several entries share
 * implementations:
 * - 'FusedConv' reuses `conv` / `parseConvAttributes`.
 * - 'Elu', 'LeakyRelu', 'QuickGelu' and 'ThresholdedRelu' share
 *   `unaryOps.parseAlphaAttributes`.
 *
 * Insertion order is kept exactly as registered, since Map iteration follows it.
 */
export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new Map([
  ['Abs', [unaryOps.abs]],
  ['Acos', [unaryOps.acos]],
  ['Acosh', [unaryOps.acosh]],
  ['Add', [binaryOps.add]],
  ['ArgMax', [argMax, parseArgMinMaxAttributes]],
  ['ArgMin', [argMin, parseArgMinMaxAttributes]],
  ['Asin', [unaryOps.asin]],
  ['Asinh', [unaryOps.asinh]],
  ['Atan', [unaryOps.atan]],
  ['Atanh', [unaryOps.atanh]],
  ['Attention', [attention]],
  // TODO: support new attributes for AveragePool-10
  ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
  ['BatchNormalization', [batchNorm]],
  ['BiasAdd', [biasAdd]],
  ['BiasSplitGelu', [biasSplitGelu]],
  ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
  ['Ceil', [unaryOps.ceil]],
  ['Clip', [unaryOps.clip]],
  ['Concat', [concat, parseConcatAttributes]],
  ['Conv', [conv, parseConvAttributes]],
  ['ConvTranspose', [convTranspose, parseConvTransposeAttributes]],
  ['Cos', [unaryOps.cos]],
  ['Cosh', [unaryOps.cosh]],
  ['CumSum', [cumsum, parseCumSumAttributes]],
  ['DepthToSpace', [depthToSpace, parseDepthToSpaceAttributes]],
  ['DequantizeLinear', [dequantizeLinear, parseDequantizeLinearAttributes]],
  ['Div', [binaryOps.div]],
  ['Einsum', [einsum, parseEinsumAttributes]],
  ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
  ['Equal', [binaryOps.equal]],
  ['Erf', [unaryOps.erf]],
  ['Exp', [unaryOps.exp]],
  ['Expand', [expand]],
  ['FastGelu', [fastGelu]],
  ['Floor', [unaryOps.floor]],
  ['FusedConv', [conv, parseConvAttributes]],
  ['Gather', [gather, parseGatherAttributes]],
  ['GatherElements', [gatherElements, parseGatherElementsAttributes]],
  ['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]],
  ['Gelu', [unaryOps.gelu]],
  ['Gemm', [gemm, parseGemmAttributes]],
  ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
  ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
  ['Greater', [binaryOps.greater]],
  ['GreaterOrEqual', [binaryOps.greaterOrEqual]],
  ['GroupQueryAttention', [groupQueryAttention, parseGroupQueryAttentionAttributes]],
  ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
  ['InstanceNormalization', [instanceNorm]],
  ['LayerNormalization', [layerNorm]],
  ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
  ['Less', [binaryOps.less]],
  ['LessOrEqual', [binaryOps.lessOrEqual]],
  ['Log', [unaryOps.log]],
  ['MatMul', [matMul]],
  ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]],
  // TODO: support new attributes for MaxPool-8 and MaxPool-10
  ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
  ['Mul', [binaryOps.mul]],
  ['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
  ['Neg', [unaryOps.neg]],
  ['Not', [unaryOps.not]],
  ['Pad', [pad]],
  ['Pow', [binaryOps.pow]],
  ['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]],
  ['Range', [range]],
  ['Reciprocal', [unaryOps.reciprocal]],
  ['ReduceMin', [reduceMin]],
  ['ReduceMean', [reduceMean]],
  ['ReduceMax', [reduceMax]],
  ['ReduceSum', [reduceSum]],
  ['ReduceProd', [reduceProd]],
  ['ReduceL1', [reduceL1]],
  ['ReduceL2', [reduceL2]],
  ['ReduceLogSum', [reduceLogSum]],
  ['ReduceLogSumExp', [reduceLogSumExp]],
  ['ReduceSumSquare', [reduceSumSquare]],
  ['Relu', [unaryOps.relu]],
  ['Resize', [resize, parseResizeAttributes]],
  ['RotaryEmbedding', [rotaryEmbedding]],
  ['Sigmoid', [unaryOps.sigmoid]],
  ['Sin', [unaryOps.sin]],
  ['Sinh', [unaryOps.sinh]],
  ['Slice', [slice, parseSliceAttributes]],
  ['SkipLayerNormalization', [skipLayerNorm]],
  ['Split', [split, parseSplitAttributes]],
  ['Sqrt', [unaryOps.sqrt]],
  ['Softmax', [softmax, parseSoftmaxAttributes]],
  ['Sub', [binaryOps.sub]],
  ['Tan', [unaryOps.tan]],
  ['Tanh', [unaryOps.tanh]],
  ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
  ['Tile', [tile]],
  ['Transpose', [transpose, parseTransposeAttributes]],
  ['Where', [where]],
]);

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.