# Hosting-page header (pytorch repository fork): _distributed_rpc.pyi — 188 lines, 5.9 KB.
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from datetime import timedelta
from typing import Any, Generic, overload, TypeVar
import torch
from torch._C import Future
from torch._C._autograd import ProfilerEvent
from torch._C._distributed_c10d import Store
from torch._C._profiler import ProfilerConfig

# The bindings declared in this stub are implemented in
# torch/csrc/distributed/rpc/init.cpp

# Module-level constants exported by the C++ extension. Only their types are
# declared in this stub; the values come from the native module.
_DEFAULT_INIT_METHOD: str
_DEFAULT_NUM_WORKER_THREADS: int
_UNSET_RPC_TIMEOUT: float
_DEFAULT_RPC_TIMEOUT_SEC: float

# Type variable for the value held by a PyRRef (and carried by its futures).
_T = TypeVar("_T")

class RpcBackendOptions:
    # Options shared by RPC backends; _TensorPipeRpcBackendOptionsBase
    # extends this class.
    # rpc_timeout: presumably seconds, like _DEFAULT_RPC_TIMEOUT_SEC — TODO confirm.
    rpc_timeout: float
    init_method: str
    # Both arguments are optional; the stub elides their defaults with `...`.
    def __init__(
        self,
        rpc_timeout: float = ...,
        init_method: str = ...,
    ) -> None: ...

class WorkerInfo:
    # Describes an RPC worker through its `name` and integer `id` properties.
    # NOTE(review): the constructor keyword is `worker_id` while the property
    # is `id`; confirm the keyword name against the py::arg in init.cpp.
    def __init__(self, name: str, worker_id: int) -> None: ...
    @property
    def name(self) -> str: ...
    @property
    def id(self) -> int: ...
    # Only equality is declared here; no __hash__ appears in this stub.
    def __eq__(self, other: object) -> bool: ...

class RpcAgent:
    # Common RPC agent interface; TensorPipeAgent subclasses it.
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def sync(self): ...
    def shutdown(self): ...
    # Without an argument presumably returns this agent's own WorkerInfo;
    # with a worker name, returns that worker's WorkerInfo.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    # Device-to-device mapping used towards `dst` (presumably local -> remote).
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    # Diagnostics, both as string-to-string dictionaries.
    def get_debug_info(self) -> dict[str, str]: ...
    def get_metrics(self) -> dict[str, str]: ...

class PyRRef(Generic[_T]):
    # Python binding of an RRef (remote reference) whose value has type _T.
    def __init__(self, value: _T, type_hint: Any = None) -> None: ...
    # Ownership queries.
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    # Value access: to_here is typed as _T, local_value as Any.
    def to_here(self, timeout: float = ...) -> _T: ...
    def local_value(self) -> Any: ...
    # Declared to return Any in this stub.
    def rpc_sync(self, timeout: float = ...) -> Any: ...
    def rpc_async(self, timeout: float = ...) -> Any: ...
    def remote(self, timeout: float = ...) -> Any: ...
    # Serialization round-trip through a plain tuple.
    def _serialize(self) -> tuple: ...
    @staticmethod
    def _deserialize(tp: tuple) -> PyRRef: ...
    def _get_type(self) -> type[_T]: ...
    # Futures parameterized by the RRef's value type.
    def _get_future(self) -> Future[_T]: ...
    def _get_profiling_future(self) -> Future[_T]: ...
    def _set_profiling_future(self, profilingFuture: Future[_T]): ...

class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    # TensorPipe-specific options layered on top of RpcBackendOptions.
    num_worker_threads: int
    # Outer key presumably a destination worker name (as in _set_device_map's
    # `to: str`); inner dict maps one torch.device to another.
    device_maps: dict[str, dict[torch.device, torch.device]]
    devices: list[torch.device]
    def __init__(
        self,
        num_worker_threads: int,
        _transports: list | None,
        _channels: list | None,
        rpc_timeout: float = ...,
        init_method: str = ...,
        # The mutable defaults below are never shared: this is a .pyi stub,
        # so they only tell type checkers that "empty" is the default.
        device_maps: dict[str, dict[torch.device, torch.device]] = {},  # noqa: B006
        devices: list[torch.device] = [],  # noqa: B006
    ) -> None: ...
    # Sets the device map towards worker `to`.
    def _set_device_map(
        self,
        to: str,
        device_map: dict[torch.device, torch.device],
    ): ...

class TensorPipeAgent(RpcAgent):
    # Concrete RpcAgent constructed from a Store, this worker's name/id and
    # _TensorPipeRpcBackendOptionsBase options.
    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        # May be None; presumably for a non-static group (compare
        # is_static_group / _update_group_membership) — TODO confirm.
        world_size: int | None,
        opts: _TensorPipeRpcBackendOptionsBase,
        reverse_device_maps: dict[str, dict[torch.device, torch.device]],
        devices: list[torch.device],
    ) -> None: ...
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def shutdown(self): ...
    # In addition to the RpcAgent overloads, a worker can be looked up by int id.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> list[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> dict[torch.device, torch.device]: ...
    # `is_join` presumably selects adding vs. removing `worker_info` — TODO confirm.
    def _update_group_membership(
        self,
        worker_info: WorkerInfo,
        my_devices: list[torch.device],
        reverse_device_map: dict[str, dict[torch.device, torch.device]],
        is_join: bool,
    ): ...
    def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
    # Read-only properties (no setters declared).
    @property
    def is_static_group(self) -> bool: ...
    @property
    def store(self) -> Store: ...

# Presumably reports whether a current RPC agent has been installed.
def _is_current_rpc_agent_set() -> bool:
    ...
# Returns the process's current RpcAgent (presumably the one passed to
# _set_and_start_rpc_agent).
def _get_current_rpc_agent() -> RpcAgent:
    ...
# Installs `agent` as the current agent and starts it (per the name — TODO confirm).
def _set_and_start_rpc_agent(agent: RpcAgent):
    ...
# Clears the current RPC agent; takes no arguments and is unannotated.
def _reset_current_rpc_agent():
    ...
# Unlike the float-second timeouts elsewhere in this module, `timeout` here
# is a datetime.timedelta; its default is elided with `...`.
def _delete_all_user_and_unforked_owner_rrefs(
    timeout: timedelta = ...,
):
    ...
# Tears down the RRef context; `ignoreRRefLeak` presumably suppresses leak
# reporting — TODO confirm.
def _destroy_rref_context(ignoreRRefLeak: bool):
    ...
# Debug information about the RRef context as a string-to-string dict.
def _rref_context_get_debug_info() -> dict[str, str]:
    ...
# Cleanup hook for the Python RPC handler; no arguments, unannotated return.
def _cleanup_python_rpc_handler():
    ...
# Issues an RPC that runs the builtin operator named `opName` on worker
# `dst`, with a timeout given in seconds; extra positional/keyword arguments
# are presumably forwarded to the operator.
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any,
): ...
# Issues an RPC that runs a pickled Python UDF on worker `dst`; the tensors
# travel alongside the pickled payload. Timeout is in seconds.
# NOTE(review): pickled payloads are usually bytes; confirm `str` is right.
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: list[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
# Issues an RPC that runs the TorchScript function `qualifiedNameStr` on the
# worker named `dstWorkerName`. Arguments are passed as an explicit tuple and
# dict rather than *args/**kwargs. Timeout is in seconds.
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: tuple,
    kwargsDict: dict,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
# Remote (not sync/async RPC) counterpart of _invoke_rpc_builtin: runs the
# builtin `opName` on worker `dst`; extra arguments are presumably forwarded
# to the operator. Timeout is in seconds.
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any,
): ...
# Remote counterpart of _invoke_rpc_python_udf: runs a pickled Python UDF on
# worker `dst`, shipping `tensors` with it. Timeout is in seconds.
# NOTE(review): pickled payloads are usually bytes; confirm `str` is right.
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: str,
    tensors: list[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
): ...
# Remote counterpart of _invoke_rpc_torchscript: runs the TorchScript
# function `qualifiedNameStr` on the worker named `dstWorkerName`, forwarding
# *args/**kwargs. Timeout is in seconds. The return is left unannotated.
#
# `dstWorkerName` is a worker *name* (str), consistent with the parameter
# name and with _invoke_rpc_torchscript; it was previously annotated as
# WorkerInfo. NOTE(review): torch.distributed.rpc.api.remote() passes
# `dst_worker_info.name` here — confirm against init.cpp.
def _invoke_remote_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any,
): ...
# Current default RPC timeout as a float (presumably seconds, cf. _set_rpc_timeout).
def get_rpc_timeout() -> float:
    ...
# Turns GIL profiling on or off according to `flag`.
def enable_gil_profiling(flag: bool):
    ...
# Sets the default RPC timeout, expressed in seconds.
def _set_rpc_timeout(rpcTimeoutSeconds: float):
    ...

class RemoteProfilerManager:
    # Exposes a single static hook; no instance state is declared.
    @staticmethod
    def set_current_profiling_key(key: str): ...

# Enables the server-process-wide profiler using `new_config`.
def _enable_server_process_global_profiler(new_config: ProfilerConfig):
    ...
# Disables the server-process-wide profiler; the result is a three-level
# nested list of ProfilerEvent (what each level indexes is not stated here).
def _disable_server_process_global_profiler() -> list[list[list[ProfilerEvent]]]:
    ...
# Sets the integer node id used by the profiler by default.
def _set_profiler_node_id(default_node_id: int):
    ...
# Switches JIT RRef pickling on (paired with _disable_jit_rref_pickle).
def _enable_jit_rref_pickle():
    ...
# Switches JIT RRef pickling off (paired with _enable_jit_rref_pickle).
def _disable_jit_rref_pickle():
    ...
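
# ---------------------------------------------------------------------------
# Hosting-page residue (GitVerse cookie notice); not part of the stub.
# English: "Use of cookies. We use cookies in accordance with the Privacy
# Policy and the Cookie Policy. By clicking 'Accept', you consent to SberTech
# JSC processing your personal data in order to improve our website and the
# GitVerse service and make them more convenient to use. You can disable
# cookies yourself in your browser settings."
#
# Использование cookies
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.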