onnxruntime / wasm-training-core-impl.ts (631 lines · 20.2 KB)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { InferenceSession, Tensor } from 'onnxruntime-common';

import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { setRunOptions } from './run-options';
import { setSessionOptions } from './session-options';
import {
  dataLocationStringToEnum,
  tensorDataTypeEnumToString,
  tensorDataTypeStringToEnum,
  tensorTypeToTypedArrayConstructor,
} from './wasm-common';
import { prepareInputOutputTensor } from './wasm-core-impl';
import { getInstance } from './wasm-factory';
import { checkLastError } from './wasm-utils';

const NO_TRAIN_FUNCS_MSG =
  'Built without training APIs enabled. Use the onnxruntime-web/training import for training ' +
  'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
  'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';

/**
 * Runs the checkLastError function, which throws an error if the provided error code matches the expected
 * error pattern.
 * @param errCode number to evaluate as a potential error code
 * @param message message to pass into checkLastError
 * @param checkNeqZero when true, treats a non-zero value as an error;
 *                     when false, treats zero as an error.
 */
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
  if (checkNeqZero && errCode !== 0) {
    checkLastError(message);
  } else if (!checkNeqZero && errCode === 0) {
    checkLastError(message);
  }
};

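/**
 * Creates a CheckpointState in the WASM module from the given serialized checkpoint data and returns its handle.
 * The checkpoint data buffer is freed from the WASM heap before returning, whether or not loading succeeds.
 *
 * @param checkpointData offset and length of the serialized checkpoint data on the WASM heap
 * @returns handle to the created CheckpointState
 */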
export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
  const wasm = getInstance();

  const [checkpointDataOffset, checkpointDataLength] = checkpointData;
  let checkpointHandle = 0;

  try {
    if (wasm._OrtTrainingLoadCheckpoint) {
      checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
    return checkpointHandle;
  } catch (e) {
    if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
      wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
    }
    throw e;
  } finally {
    // free buffer from wasm heap
    wasm._OrtFree(checkpointData[0]);
  }
};

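/**
 * Queries the WASM module for the number of inputs and outputs of the training or eval model.
 *
 * @param trainingSessionId handle of the training session
 * @param isEvalModel when true, queries the eval model; otherwise queries the training model
 * @returns an [inputCount, outputCount] tuple
 */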
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
  const wasm = getInstance();
  const stack = wasm.stackSave();
  try {
    const dataOffset = wasm.stackAlloc(8);
    if (wasm._OrtTrainingGetModelInputOutputCount) {
      const errorCode = wasm._OrtTrainingGetModelInputOutputCount(
        trainingSessionId,
        dataOffset,
        dataOffset + 4,
        isEvalModel,
      );
      ifErrCodeCheckLastError(errorCode, "Can't get session input/output count.");
      return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};

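/**
 * Retrieves `count` input or output names from the WASM module, copying each name out of WASM memory and freeing
 * the WASM-side string before moving on to the next one.
 *
 * @param trainingSessionId handle of the training session
 * @param count number of names to retrieve
 * @param isInput when true, retrieves input names; otherwise retrieves output names
 * @param isEvalModel when true, queries the eval model; otherwise queries the training model
 * @returns list of names as JavaScript strings
 */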
const getModelInputOutputNamesLoop = (
  trainingSessionId: number,
  count: number,
  isInput: boolean,
  isEvalModel: boolean,
): string[] => {
  const names = [];
  const wasm = getInstance();

  for (let i = 0; i < count; i++) {
    if (wasm._OrtTrainingGetModelInputOutputName) {
      const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
      ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);

      names.push(wasm.UTF8ToString(name));
      wasm._free(name);
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  }
  return names;
};

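/**
 * Retrieves the input and output names of the training or eval model associated with the given session.
 *
 * @param trainingSessionId handle of the training session
 * @param isEvalModel when true, queries the eval model; otherwise queries the training model
 * @returns an [inputNames, outputNames] tuple
 */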
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
  let inputNames: string[] = [];
  let outputNames: string[] = [];

  const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);

  inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
  outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);

  return [inputNames, outputNames];
};

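/**
 * Creates a TrainingSession in the WASM module from the given checkpoint handle, serialized model data, and session
 * options, and returns its handle. The train, eval, and optimizer model buffers are freed from the WASM heap before
 * returning, whether or not session creation succeeds.
 *
 * @param checkpointHandle handle of a previously created CheckpointState
 * @param trainModelData offset and length of the serialized training model on the WASM heap
 * @param evalModelData offset and length of the serialized eval model on the WASM heap
 * @param optimizerModelData offset and length of the serialized optimizer model on the WASM heap
 * @param options session options
 * @returns handle to the created TrainingSession
 */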
export const createTrainingSessionHandle = (
  checkpointHandle: number,
  trainModelData: SerializableInternalBuffer,
  evalModelData: SerializableInternalBuffer,
  optimizerModelData: SerializableInternalBuffer,
  options: InferenceSession.SessionOptions,
): number => {
  const wasm = getInstance();

  let trainingSessionHandle = 0;
  let sessionOptionsHandle = 0;
  let allocs: number[] = [];

  try {
    [sessionOptionsHandle, allocs] = setSessionOptions(options);
    if (wasm._OrtTrainingCreateSession) {
      trainingSessionHandle = wasm._OrtTrainingCreateSession(
        sessionOptionsHandle,
        checkpointHandle,
        trainModelData[0],
        trainModelData[1],
        evalModelData[0],
        evalModelData[1],
        optimizerModelData[0],
        optimizerModelData[1],
      );
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
    return trainingSessionHandle;
  } catch (e) {
    if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
      wasm._OrtTrainingReleaseSession(trainingSessionHandle);
    }
    throw e;
  } finally {
    wasm._free(trainModelData[0]);
    wasm._free(evalModelData[0]);
    wasm._free(optimizerModelData[0]);

    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach((alloc) => wasm._free(alloc));
  }
};

/**
 * Prepares input and output tensors by creating the tensors on the WASM side, then creates a list of the handles of
 * the WASM tensors.
 *
 * @param trainingSessionId
 * @param indices for each tensor, the index of the input or output name that the tensor corresponds to
 * @param tensors list of TensorMetadata
 * @param tensorHandles pass in an empty list of numbers; modified in-place by this method to store the resulting
 *                      handles of the allocated tensors on the heap
 * @param inputOutputAllocs modified in-place by this method
 * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
 */
const createAndAllocateTensors = (
  trainingSessionId: number,
  indices: number[],
  tensors: Array<TensorMetadata | null>,
  tensorHandles: number[],
  inputOutputAllocs: number[],
  indexAdd: number,
) => {
  const count = indices.length;

  // creates the tensors
  for (let i = 0; i < count; i++) {
    prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
  }

  // moves to heap
  const wasm = getInstance();
  const valuesOffset = wasm.stackAlloc(count * 4);
  let valuesIndex = valuesOffset / 4;
  for (let i = 0; i < count; i++) {
    wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
  }

  return valuesOffset;
};

/**
 * Retrieves the information from the output tensor handles, copies it to an array, and frees the WASM information
 * associated with each tensor handle.
 *
 * @param outputValuesOffset offset on the WASM stack where the output tensor handles are stored
 * @param outputCount number of output tensors
 * @param outputTensorHandles tensor handles that were created for the requested outputs
 * @param outputTensors list of pre-allocated output TensorMetadata, or null where the output is allocated by the
 *                      session
 * @returns list of TensorMetadata retrieved from the output handles.
 */
const moveOutputToTensorMetadataArr = (
  outputValuesOffset: number,
  outputCount: number,
  outputTensorHandles: number[],
  outputTensors: Array<TensorMetadata | null>,
) => {
  const wasm = getInstance();
  const output: TensorMetadata[] = [];

  for (let i = 0; i < outputCount; i++) {
    const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
    if (tensor === outputTensorHandles[i]) {
      // output tensor is pre-allocated. no need to copy data.
      output.push(outputTensors[i]!);
      continue;
    }

    const beforeGetTensorDataStack = wasm.stackSave();
    // stack allocate 4 pointer value
    const tensorDataOffset = wasm.stackAlloc(4 * 4);

    let type: Tensor.Type | undefined,
      dataOffset = 0;
    try {
      const errorCode = wasm._OrtGetTensorData(
        tensor,
        tensorDataOffset,
        tensorDataOffset + 4,
        tensorDataOffset + 8,
        tensorDataOffset + 12,
      );
      ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);

      let tensorDataIndex = tensorDataOffset / 4;
      const dataType = wasm.HEAPU32[tensorDataIndex++];
      dataOffset = wasm.HEAPU32[tensorDataIndex++];
      const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
      const dimsLength = wasm.HEAPU32[tensorDataIndex++];
      const dims = [];
      for (let i = 0; i < dimsLength; i++) {
        dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
      }
      wasm._OrtFree(dimsOffset);

      const size = dims.reduce((a, b) => a * b, 1);
      type = tensorDataTypeEnumToString(dataType);

      if (type === 'string') {
        const stringData: string[] = [];
        let dataIndex = dataOffset / 4;
        for (let i = 0; i < size; i++) {
          const offset = wasm.HEAPU32[dataIndex++];
          const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
          stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
        }
        output.push([type, dims, stringData, 'cpu']);
      } else {
        const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
        const data = new typedArrayConstructor(size);
        new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
          wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
        );
        output.push([type, dims, data, 'cpu']);
      }
    } finally {
      wasm.stackRestore(beforeGetTensorDataStack);
      if (type === 'string' && dataOffset) {
        wasm._free(dataOffset);
      }
      wasm._OrtReleaseTensor(tensor);
    }
  }

  return output;
};

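/**
 * Requests a lazy reset of the gradient buffers for the given training session by calling OrtTrainingLazyResetGrad.
 *
 * @param trainingSessionId handle of the training session
 */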
export const lazyResetGrad = async (trainingSessionId: number): Promise<void> => {
  const wasm = getInstance();

  if (wasm._OrtTrainingLazyResetGrad) {
    const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
    ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad.");
  } else {
    throw new Error(NO_TRAIN_FUNCS_MSG);
  }
};

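/**
 * Runs a single train step on the training model. Prepares the run options and the input/output tensors on the WASM
 * heap, calls OrtTrainingRunTrainStep, copies the resulting outputs back to JavaScript, and frees all WASM-side
 * allocations before returning.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of the corresponding input name
 * @param inputTensors list of input TensorMetadata
 * @param outputIndices for each output tensor, the index of the corresponding output name
 * @param outputTensors list of pre-allocated output TensorMetadata, or null where the output should be allocated by
 *                      the session
 * @param options run options
 * @returns list of output TensorMetadata
 */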
export const runTrainStep = async (
  trainingSessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      inputIndices,
      inputTensors,
      inputTensorHandles,
      inputOutputAllocs,
      0,
    );
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      outputIndices,
      outputTensors,
      outputTensorHandles,
      inputOutputAllocs,
      inputCount,
    );

    if (wasm._OrtTrainingRunTrainStep) {
      const errorCode = wasm._OrtTrainingRunTrainStep(
        trainingSessionId,
        inputValuesOffset,
        inputCount,
        outputValuesOffset,
        outputCount,
        runOptionsHandle,
      );
      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

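/**
 * Runs a single optimizer step for the given training session by calling OrtTrainingOptimizerStep with the provided
 * run options.
 *
 * @param trainingSessionId handle of the training session
 * @param options run options
 */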
export const runOptimizerStep = async (
  trainingSessionId: number,
  options: InferenceSession.RunOptions,
): Promise<void> => {
  const wasm = getInstance();

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  try {
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    if (wasm._OrtTrainingOptimizerStep) {
      const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
      ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

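/**
 * Runs a single eval step on the eval model. Mirrors runTrainStep: prepares the run options and the input/output
 * tensors on the WASM heap, calls OrtTrainingEvalStep, copies the resulting outputs back to JavaScript, and frees
 * all WASM-side allocations before returning.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of the corresponding input name
 * @param inputTensors list of input TensorMetadata
 * @param outputIndices for each output tensor, the index of the corresponding output name
 * @param outputTensors list of pre-allocated output TensorMetadata, or null where the output should be allocated by
 *                      the session
 * @param options run options
 * @returns list of output TensorMetadata
 */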
export const runEvalStep = async (
  trainingSessionId: number,
  inputIndices: number[],
  inputTensors: TensorMetadata[],
  outputIndices: number[],
  outputTensors: Array<TensorMetadata | null>,
  options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      inputIndices,
      inputTensors,
      inputTensorHandles,
      inputOutputAllocs,
      0,
    );
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
      trainingSessionId,
      outputIndices,
      outputTensors,
      outputTensorHandles,
      inputOutputAllocs,
      inputCount,
    );

    if (wasm._OrtTrainingEvalStep) {
      const errorCode = wasm._OrtTrainingEvalStep(
        trainingSessionId,
        inputValuesOffset,
        inputCount,
        outputValuesOffset,
        outputCount,
        runOptionsHandle,
      );

      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach((p) => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach((p) => wasm._free(p));
  }
};

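/**
 * Retrieves the size of the parameters in the training state, as a count of elements rather than bytes.
 *
 * @param trainingSessionId handle of the training session
 * @param trainableOnly when true, counts only the trainable parameters
 * @returns number of parameter elements
 */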
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  try {
    const sizeOffset = wasm.stackAlloc(4);
    if (wasm._OrtTrainingGetParametersSize) {
      const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
      ifErrCodeCheckLastError(errorCode, "Can't get parameters size");

      return wasm.HEAP32[sizeOffset / 4];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};

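/**
 * Copies the parameters of the training state into a single contiguous float32 tensor. A buffer of the required size
 * is allocated on the WASM heap, wrapped in a tensor, filled via OrtTrainingCopyParametersToBuffer, and copied out
 * into a JavaScript typed array; all WASM-side allocations are freed before returning.
 *
 * @param trainingSessionId handle of the training session
 * @param trainableOnly when true, copies only the trainable parameters
 * @returns a one-dimensional float32 TensorMetadata containing the parameters
 */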
export const getContiguousParameters = async (
  trainingSessionId: number,
  trainableOnly: boolean,
): Promise<TensorMetadata> => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  const tensorTypeAsString = 'float32';
  const locationAsString = 'cpu';

  const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
  let tensor = 0;

  // allocates a buffer of the correct size on the WASM heap
  const paramsByteLength = 4 * parametersSize;
  const paramsOffset = wasm._malloc(paramsByteLength);

  // handles the dimensions-related createTensor parameters
  const dims = [parametersSize];

  const dimsOffset = wasm.stackAlloc(4);
  const dimsIndex = dimsOffset / 4;
  wasm.HEAP32[dimsIndex] = parametersSize;

  try {
    // wraps allocated array in a tensor
    tensor = wasm._OrtCreateTensor(
      tensorDataTypeStringToEnum(tensorTypeAsString),
      paramsOffset,
      paramsByteLength,
      dimsOffset,
      dims.length,
      dataLocationStringToEnum(locationAsString),
    );
    ifErrCodeCheckLastError(
      tensor,
      `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`,
      false,
    );

    if (wasm._OrtTrainingCopyParametersToBuffer) {
      const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
      ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters.");
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
    const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
    const data = new typedArrayConstructor(parametersSize);
    const output: TensorMetadata[] = [];
    new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
      wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength),
    );
    output.push([tensorTypeAsString, dims, data, locationAsString]);
    if (output.length !== 1) {
      throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
     one, got ${output.length}`);
    } else {
      return output[0];
    }
  } finally {
    if (tensor !== 0) {
      wasm._OrtReleaseTensor(tensor);
    }
    wasm._free(paramsOffset);
    wasm._free(dimsOffset);
    wasm.stackRestore(stack);
  }
};

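/**
 * Copies a contiguous buffer of parameter data from JavaScript into the parameters of the training state. The buffer
 * is copied onto the WASM heap, wrapped in a float32 tensor, and passed to OrtTrainingCopyParametersFromBuffer; all
 * WASM-side allocations are freed before returning.
 *
 * @param trainingSessionId handle of the training session
 * @param buffer parameter data as raw bytes, interpreted as float32 values (4 bytes per parameter)
 * @param trainableOnly when true, loads only the trainable parameters
 */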
export const loadParametersBuffer = async (
  trainingSessionId: number,
  buffer: Uint8Array,
  trainableOnly: boolean,
): Promise<void> => {
  const wasm = getInstance();
  const stack = wasm.stackSave();

  const tensorTypeAsString = 'float32';
  const locationAsString = 'cpu';

  // allocates & copies JavaScript buffer to WASM heap
  const bufferByteLength = buffer.length;
  const bufferCount = bufferByteLength / 4;
  const bufferOffset = wasm._malloc(bufferByteLength);
  wasm.HEAPU8.set(buffer, bufferOffset);

  // allocates and handles moving dimensions information to WASM memory
  const dimsOffset = wasm.stackAlloc(4);
  wasm.HEAP32[dimsOffset / 4] = bufferCount;
  const dimsLength = 1;
  let tensor = 0;

  try {
    tensor = wasm._OrtCreateTensor(
      tensorDataTypeStringToEnum(tensorTypeAsString),
      bufferOffset,
      bufferByteLength,
      dimsOffset,
      dimsLength,
      dataLocationStringToEnum(locationAsString),
    );
    ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);

    if (wasm._OrtTrainingCopyParametersFromBuffer) {
      const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
      ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters.");
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    if (tensor !== 0) {
      wasm._OrtReleaseTensor(tensor);
    }
    wasm.stackRestore(stack);
    wasm._free(bufferOffset);
    wasm._free(dimsOffset);
  }
};

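/**
 * Releases the native resources associated with the given TrainingSession and CheckpointState handles.
 *
 * @param checkpointId handle of the CheckpointState to release
 * @param sessionId handle of the TrainingSession to release
 */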
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
  const wasm = getInstance();

  if (wasm._OrtTrainingReleaseSession) {
    wasm._OrtTrainingReleaseSession(sessionId);
  }
  if (wasm._OrtTrainingReleaseCheckpoint) {
    wasm._OrtTrainingReleaseCheckpoint(checkpointId);
  }
};