onnxruntime

Форк
0
271 строка · 9.5 Кб
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
// Licensed under the MIT License.
3

4
import { Env } from 'onnxruntime-common';
5

6
import { calculateTensorSizeInBytes, DataType } from '../wasm-common';
7

8
import type { OrtWasmModule } from '../wasm-types';
9

10
import { WebGpuBackend } from './backend-webgpu';
11
import { LOG_DEBUG } from './log';
12
import { TensorView } from './tensor-view';
13
import { ShapeUtil } from './util';
14
import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types';
15

16
/* eslint-disable no-bitwise */
17

18
class TensorViewImpl implements TensorView {
19
  constructor(
20
    private module: OrtWasmModule,
21
    public readonly dataType: number,
22
    public readonly data: number,
23
    public readonly dims: readonly number[],
24
  ) {}
25

26
  getFloat32Array(): Float32Array {
27
    if (this.dataType !== DataType.float) {
28
      throw new Error('Invalid data type');
29
    }
30
    const elementCount = ShapeUtil.size(this.dims);
31
    return elementCount === 0
32
      ? new Float32Array()
33
      : new Float32Array(this.module.HEAP8.buffer, this.data, elementCount);
34
  }
35

36
  getBigInt64Array(): BigInt64Array {
37
    if (this.dataType !== DataType.int64) {
38
      throw new Error('Invalid data type');
39
    }
40
    const elementCount = ShapeUtil.size(this.dims);
41
    return elementCount === 0
42
      ? new BigInt64Array()
43
      : new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount);
44
  }
45

46
  getInt32Array(): Int32Array {
47
    if (this.dataType !== DataType.int32) {
48
      throw new Error('Invalid data type');
49
    }
50
    const elementCount = ShapeUtil.size(this.dims);
51
    return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount);
52
  }
53

54
  getUint16Array(): Uint16Array {
55
    if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
56
      throw new Error('Invalid data type');
57
    }
58
    const elementCount = ShapeUtil.size(this.dims);
59
    return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
60
  }
61

62
  reshape(newDims: readonly number[]): TensorView {
63
    if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
64
      throw new Error('Invalid new shape');
65
    }
66
    return new TensorViewImpl(this.module, this.dataType, this.data, newDims);
67
  }
68
}
69

70
/**
 * Implementation of ComputeContext for a single JS kernel invocation.
 *
 * The constructor decodes a packed context struct that the wasm side places in the
 * heap at `contextDataOffset`; `compute()` dispatches a GPU program through the
 * WebGPU backend, and `output()` asks the wasm side to allocate a kernel output.
 */
class ComputeContextImpl implements ComputeContext {
  readonly adapterInfo: AdapterInfo;
  // pointer (wasm address) of the OpKernelContext on the native side.
  readonly opKernelContext: number;
  readonly inputs: readonly TensorView[];
  readonly outputCount: number;
  get kernelCustomData(): { [key: string]: unknown } {
    return this.backend.currentKernelCustomData;
  }
  get customDataBuffer(): Uint8Array {
    // live view into the wasm heap; valid only while the heap is not re-grown.
    return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize);
  }
  private customDataOffset = 0;
  private customDataSize = 0;
  constructor(
    private module: OrtWasmModule,
    private backend: WebGpuBackend,
    contextDataOffset: number,
  ) {
    this.adapterInfo = backend.adapterInfo;
    const heapU32 = module.HEAPU32;

    // extract context data
    //
    // The struct is read as consecutive uint32 values in this exact order:
    //   [opKernelContext, inputCount, outputCount, customDataOffset, customDataSize,
    //    then per input: dataType, dataPtr, rank, dims[0..rank-1]]
    // The read order must match the layout written by the wasm side — do not reorder.
    let dataIndex = contextDataOffset >>> 2; // byte offset -> uint32 index
    this.opKernelContext = heapU32[dataIndex++];
    const inputCount = heapU32[dataIndex++];
    this.outputCount = heapU32[dataIndex++];
    this.customDataOffset = heapU32[dataIndex++];
    this.customDataSize = heapU32[dataIndex++];

    const inputs: TensorView[] = [];
    for (let i = 0; i < inputCount; i++) {
      const dataType = heapU32[dataIndex++];
      const data = heapU32[dataIndex++];
      const dim = heapU32[dataIndex++];
      const dims: number[] = [];
      for (let d = 0; d < dim; d++) {
        dims.push(heapU32[dataIndex++]);
      }
      inputs.push(new TensorViewImpl(module, dataType, data, dims));
    }
    this.inputs = inputs;
  }

  getMaxComputeWorkgroupSizes(): [number, number, number] {
    return [
      this.backend.device.limits.maxComputeWorkgroupSizeX,
      this.backend.device.limits.maxComputeWorkgroupSizeY,
      this.backend.device.limits.maxComputeWorkgroupSizeZ,
    ];
  }

  // NOTE: lower-case 's' in "Storagesize" is part of the public ComputeContext
  // interface and must not be "fixed" here.
  getMaxComputeWorkgroupStoragesize(): number {
    return this.backend.device.limits.maxComputeWorkgroupStorageSize;
  }

  /**
   * Runs a GPU program via the backend.
   *
   * `inputsOutputsMapping.inputs` entries may be indices into `this.inputs` or
   * explicit TensorViews; when omitted, all kernel inputs are used. Outputs are
   * created lazily by the backend through the two factory callbacks below.
   */
  compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
    // prepare inputs. inputs should always be valid data.
    const mappedInputs =
      inputsOutputsMapping?.inputs?.map((i) => (typeof i === 'number' ? this.inputs[i] : i)) ?? this.inputs;
    // prepare outputs.
    const outputIndices = inputsOutputsMapping?.outputs ?? [];
    // factory for real kernel outputs: asks the wasm side for the output buffer.
    const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
      new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);
    // factory for intermediate (temporary) GPU buffers owned by the GPU data manager.
    const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
      const bufferSize = calculateTensorSizeInBytes(dataType, dims);
      if (!bufferSize) {
        throw new Error(`Unsupported data type: ${dataType}`);
      }
      const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
      return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
    };
    return this.backend.run(
      program,
      mappedInputs,
      outputIndices,
      createKernelOutput,
      createTemporaryOutput,
      this.outputCount,
    );
  }

  /**
   * Requests the wasm side to create kernel output `index` with shape `dims`.
   *
   * Marshals the shape onto the Emscripten stack as [rank, dims...] (uint32 each),
   * calls `_JsepOutput`, and returns the resulting data id/pointer. The stack is
   * always restored in `finally`, even when `_JsepOutput` throws.
   */
  output(index: number, dims: readonly number[]): number {
    const stack = this.module.stackSave();
    try {
      const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
      let offset = data >> 2; // byte pointer -> uint32 index
      this.module.HEAPU32[offset++] = dims.length;
      for (let i = 0; i < dims.length; i++) {
        this.module.HEAPU32[offset++] = dims[i];
      }
      return this.module._JsepOutput!(this.opKernelContext, index, data);
    } catch (e) {
      throw new Error(
        `Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
          'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
          `Error: ${e}`,
      );
    } finally {
      this.module.stackRestore(stack);
    }
  }
}
172

173
/**
174
 * Initialize JSEP with WebGPU backend.
175
 *
176
 * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
177
 * each of the following EPs if they are specified:
178
 * - "webgpu"
179
 * - "webnn"
180
 *
181
 * For WebGPU, this function expects:
182
 *  - WebGPU is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
183
 *  - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
184
 *
185
 * For WebNN, this function expects:
186
 * - WebNN is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
187
 * - WebNN is available in current environment. (navigator.ml is not undefined)
188
 *
189
 * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
190
 * 'webgpu'/'webnn' backend.
191
 *
192
 * @param name - the name of the EP, either "webgpu" or "webnn"
193
 * @param module - the ORT WebAssembly module
194
 * @param env - the ORT environment variable (ort.env)
195
 * @param gpuAdapter - the pre-created GPU adapter
196
 */
197
export const init = async (
  name: 'webgpu' | 'webnn',
  module: OrtWasmModule,
  env: Env,
  gpuAdapter?: GPUAdapter,
): Promise<void> => {
  const jsepInit = module.jsepInit;
  if (!jsepInit) {
    throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
  }

  if (name === 'webgpu') {
    const backend = new WebGpuBackend();
    // caller guarantees a valid adapter for the 'webgpu' path (see TSDoc above),
    // hence the non-null assertion.
    await backend.initialize(env, gpuAdapter!);

    // NOTE: this is a positional tuple of callbacks consumed by the wasm side;
    // the order of entries is the contract — do not reorder or insert entries.
    jsepInit('webgpu', [
      // backend
      backend,

      // jsepAlloc()
      (size: number) => backend.alloc(size),

      // jsepFree()
      (ptr: number) => backend.free(ptr),

      // jsepCopy(src, dst, size, isSourceGpu)
      // GPU->GPU copies by data id; CPU->GPU uploads a slice of the wasm heap.
      (src: number, dst: number, size: number, isSourceGpu = false) => {
        if (isSourceGpu) {
          LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
          backend.memcpy(src, dst);
        } else {
          LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
          // '>>> 0' coerces the pointer to an unsigned 32-bit heap offset.
          const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
          backend.upload(dst, data);
        }
      },

      // jsepCopyAsync(src, dst, size)
      // GPU->CPU download; the target heap view is created lazily inside the
      // callback so it is taken after any potential heap growth.
      async (gpuDataId: number, dataOffset: number, size: number): Promise<void> => {
        LOG_DEBUG(
          'verbose',
          () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
        );

        await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
      },

      // jsepCreateKernel
      // resolves the node name from the wasm side for logging/diagnostics.
      (kernelType: string, kernelId: number, attribute: unknown) =>
        backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),

      // jsepReleaseKernel
      (kernel: number) => backend.releaseKernel(kernel),

      // jsepRun
      // builds a fresh ComputeContext from the packed context struct and runs the kernel.
      (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string | null>>) => {
        LOG_DEBUG(
          'verbose',
          () =>
            `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
        );
        const context = new ComputeContextImpl(module, backend, contextDataOffset);
        return backend.computeKernel(kernel, context, errors);
      },
      // jsepCaptureBegin
      () => backend.captureBegin(),
      // jsepCaptureEnd
      () => backend.captureEnd(),
      // jsepReplay
      () => backend.replay(),
    ]);
  } else {
    // WebNN needs no JS-side backend object; the wasm side handles it internally.
    jsepInit('webnn');
  }
};
272

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.