3
from datetime import timedelta
4
from typing import Any, Generic, overload, TypeVar
7
from torch._C import Future
8
from torch._C._autograd import ProfilerEvent
9
from torch._C._distributed_c10d import Store
10
from torch._C._profiler import ProfilerConfig
14
_DEFAULT_INIT_METHOD: str
15
_DEFAULT_NUM_WORKER_THREADS: int
16
_UNSET_RPC_TIMEOUT: float
17
_DEFAULT_RPC_TIMEOUT_SEC: float
21
class RpcBackendOptions:
26
rpc_timeout: float = ...,
27
init_method: str = ...,
31
def __init__(self, name: str, worker_id: int) -> None: ...
33
def name(self) -> str: ...
35
def id(self) -> int: ...
36
def __eq__(self, other: object) -> bool: ...
39
def join(self, shutdown: bool = False, timeout: float = 0): ...
41
def shutdown(self): ...
43
def get_worker_info(self) -> WorkerInfo: ...
45
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
46
def get_worker_infos(self) -> list[WorkerInfo]: ...
47
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
48
def get_debug_info(self) -> dict[str, str]: ...
49
def get_metrics(self) -> dict[str, str]: ...
51
class PyRRef(Generic[_T]):
52
def __init__(self, value: _T, type_hint: Any = None) -> None: ...
53
def is_owner(self) -> bool: ...
54
def confirmed_by_owner(self) -> bool: ...
55
def owner(self) -> WorkerInfo: ...
56
def owner_name(self) -> str: ...
57
def to_here(self, timeout: float = ...) -> _T: ...
58
def local_value(self) -> Any: ...
59
def rpc_sync(self, timeout: float = ...) -> Any: ...
60
def rpc_async(self, timeout: float = ...) -> Any: ...
61
def remote(self, timeout: float = ...) -> Any: ...
62
def _serialize(self) -> tuple: ...
64
def _deserialize(tp: tuple) -> PyRRef: ...
65
def _get_type(self) -> type[_T]: ...
66
def _get_future(self) -> Future[_T]: ...
67
def _get_profiling_future(self) -> Future[_T]: ...
68
def _set_profiling_future(self, profilingFuture: Future[_T]): ...
70
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
71
num_worker_threads: int
72
device_maps: dict[str, dict[torch.device, torch.device]]
73
devices: list[torch.device]
76
num_worker_threads: int,
77
_transports: list | None,
78
_channels: list | None,
79
rpc_timeout: float = ...,
80
init_method: str = ...,
81
device_maps: dict[str, dict[torch.device, torch.device]] = {},
82
devices: list[torch.device] = [],
87
device_map: dict[torch.device, torch.device],
90
class TensorPipeAgent(RpcAgent):
96
world_size: int | None,
97
opts: _TensorPipeRpcBackendOptionsBase,
98
reverse_device_maps: dict[str, dict[torch.device, torch.device]],
99
devices: list[torch.device],
101
def join(self, shutdown: bool = False, timeout: float = 0): ...
102
def shutdown(self): ...
104
def get_worker_info(self) -> WorkerInfo: ...
106
def get_worker_info(self, workerName: str) -> WorkerInfo: ...
108
def get_worker_info(self, id: int) -> WorkerInfo: ...
109
def get_worker_infos(self) -> list[WorkerInfo]: ...
110
def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
111
def _update_group_membership(
113
worker_info: WorkerInfo,
114
my_devices: list[torch.device],
115
reverse_device_map: dict[str, dict[torch.device, torch.device]],
118
def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
120
def is_static_group(self) -> bool: ...
122
def store(self) -> Store: ...
124
def _is_current_rpc_agent_set() -> bool: ...
125
def _get_current_rpc_agent() -> RpcAgent: ...
126
def _set_and_start_rpc_agent(agent: RpcAgent): ...
127
def _reset_current_rpc_agent(): ...
128
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
129
def _destroy_rref_context(ignoreRRefLeak: bool): ...
130
def _rref_context_get_debug_info() -> dict[str, str]: ...
131
def _cleanup_python_rpc_handler(): ...
132
def _invoke_rpc_builtin(
135
rpcTimeoutSeconds: float,
139
def _invoke_rpc_python_udf(
141
pickledPythonUDF: str,
142
tensors: list[torch.Tensor],
143
rpcTimeoutSeconds: float,
144
isAsyncExecution: bool,
146
def _invoke_rpc_torchscript(
148
qualifiedNameStr: str,
151
rpcTimeoutSeconds: float,
152
isAsyncExecution: bool,
154
def _invoke_remote_builtin(
157
rpcTimeoutSeconds: float,
161
def _invoke_remote_python_udf(
163
pickledPythonUDF: str,
164
tensors: list[torch.Tensor],
165
rpcTimeoutSeconds: float,
166
isAsyncExecution: bool,
168
def _invoke_remote_torchscript(
169
dstWorkerName: WorkerInfo,
170
qualifiedNameStr: str,
171
rpcTimeoutSeconds: float,
172
isAsyncExecution: bool,
176
def get_rpc_timeout() -> float: ...
177
def enable_gil_profiling(flag: bool): ...
178
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
180
class RemoteProfilerManager:
182
def set_current_profiling_key(key: str): ...
184
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
185
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]: ...
186
def _set_profiler_node_id(default_node_id: int): ...
187
def _enable_jit_rref_pickle(): ...
188
def _disable_jit_rref_pickle(): ...