onnxruntime
149 строк · 4.0 Кб
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4import { WebGLBackend } from './backends/backend-webgl';
5import { Graph } from './graph';
6import { Operator } from './operators';
7import { OpSet } from './opset';
8import { Session } from './session';
9
/**
 * Handler used for a single Session.run() call.
 */
export interface InferenceHandler {
  /**
   * Dispose the inference handler. Called as the last step in Session.run().
   */
  dispose(): void;
}
16
/**
 * Backend-specific handler that serves one Session object for its whole lifecycle.
 */
export interface SessionHandler {
  /**
   * A reference to the corresponding backend.
   */
  readonly backend: Backend;

  /**
   * A reference to the session context.
   */
  readonly context: Session.Context;

  /**
   * Optional hook: transform the graph at initialization time.
   * @param graphTransformer the graph transformer used to manipulate the model graph
   */
  transformGraph?(graphTransformer: Graph.Transformer): void;

  /**
   * Optional hook: lets the session handler know that graph initialization is complete.
   * @param graph the completely initialized graph
   */
  onGraphInitialized?(graph: Graph): void;

  /**
   * Resolve the operator from the node's name and opset version; backend specific.
   * @param node the node to resolve
   * @param opsets the list of opsets exported from the model
   * @param graph the completely initialized graph
   */
  resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator;

  /**
   * Create an instance of InferenceHandler to use in a Session.run() call.
   */
  createInferenceHandler(): InferenceHandler;

  /**
   * Dispose the session handler. Called when a session is being disposed explicitly.
   */
  dispose(): void;
}
58
/**
 * An execution backend that can create session handlers.
 */
export interface Backend {
  /**
   * Initialize the backend. Called only once, the first time the backend is to be used.
   * The result may be delivered synchronously or through a promise.
   */
  initialize(): boolean | Promise<boolean>;

  /**
   * Create an instance of SessionHandler to use in a Session object's lifecycle.
   * @param context the context of the session the handler is created for
   */
  createSessionHandler(context: Session.Context): SessionHandler;

  /**
   * Dispose the backend. Currently this is not called.
   */
  dispose(): void;
}
76
// Cache of every backend instance that has been initialized, keyed by backend name.
const backendsCache = new Map<string, Backend>();
79
/**
 * Registry of the backends that can be resolved, keyed by backend name.
 */
export const backend: Record<string, Backend> = {
  webgl: new WebGLBackend(),
};
83
/**
 * Resolve a reference to a backend.
 *
 * Hints are tried in the order given: for each one, a cached instance is
 * returned if present, otherwise an attempt is made to load the backend.
 * When no hint is given, 'webgl' is used.
 *
 * @param hint a backend name, or a list of backend names in order of preference
 * @returns the first hinted backend that is cached or loads successfully
 * @throws Error when none of the hinted backends is available
 */
export async function resolveBackend(hint?: string | readonly string[]): Promise<Backend> {
  // Normalise every accepted input shape into a list of backend names.
  const candidates: readonly string[] = !hint ? ['webgl'] : typeof hint === 'string' ? [hint] : hint;

  for (const name of candidates) {
    // Prefer an already initialized instance; only load when nothing is cached.
    const resolved = backendsCache.get(name) || (await tryLoadBackend(name));
    if (resolved) {
      return resolved;
    }
  }

  throw new Error('no available backend to use');
}
109
/**
 * Try to load and initialize the backend registered under the given name.
 *
 * @param backendHint the name to look up in the `backend` registry
 * @returns the backend when it is registered, passes the `isBackend` check and
 *     its `initialize()` yields a truthy result; `undefined` otherwise
 */
async function tryLoadBackend(backendHint: string): Promise<Backend | undefined> {
  const candidate = backend[backendHint];
  if (typeof candidate === 'undefined' || !isBackend(candidate)) {
    return undefined;
  }

  // initialize() may return a boolean or a promise of one; `await` handles both,
  // so no manual thenable detection is needed.
  const initialized = await candidate.initialize();
  if (!initialized) {
    return undefined;
  }

  // Only backends that initialized successfully are cached.
  backendsCache.set(backendHint, candidate);
  return candidate;
}
127
/**
 * Check whether a value structurally looks like a Backend instance, i.e. it
 * exposes `initialize`, `createSessionHandler` and `dispose` as functions.
 *
 * @param obj the value to inspect; any value is accepted
 * @returns `true` when all three methods are present, `false` otherwise
 *     (including for `null`, `undefined` and primitives)
 */
function isBackend(obj: unknown): boolean {
  // Property lookups on null/undefined throw, and primitives can never be a
  // backend, so reject them up front.
  if (obj === null || (typeof obj !== 'object' && typeof obj !== 'function')) {
    return false;
  }

  const candidate = obj as Record<string, unknown>;
  const requiredMethods = ['initialize', 'createSessionHandler', 'dispose'];

  return requiredMethods.every((method) => typeof candidate[method] === 'function');
}
146
/** Alias of the Backend interface. */
export type BackendType = Backend;
/** The type returned by Backend.createSessionHandler(). */
export type SessionHandlerType = ReturnType<Backend['createSessionHandler']>;
/** The type returned by SessionHandler.createInferenceHandler(). */
export type InferenceHandlerType = ReturnType<SessionHandler['createInferenceHandler']>;
150