# ${generated_comment}
# mypy: disable-error-code="type-arg"
# mypy: allow-untyped-defs

import builtins
from enum import Enum, IntEnum
from pathlib import Path
from typing import (
    Any,
    AnyStr,
    BinaryIO,
    Callable,
    ContextManager,
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Literal,
    NamedTuple,
    Optional,
    Protocol,
    Sequence,
    Set,
    SupportsIndex,
    Tuple,
    Type as _Type,
    TypeVar,
    Union,
    overload,
    runtime_checkable,
)
from typing_extensions import ParamSpec, Self

import numpy

import torch
from torch import SymInt, Tensor, inf
from torch._prims_common import DeviceLikeType
from torch.autograd.graph import Node as _Node
from torch.fx.node import Node as FxNode
from torch.package import PackageExporter
from torch.storage import TypedStorage, UntypedStorage
from torch.types import (
    Device,
    IntLikeType,
    Number,
    Storage,
    _bool,
    _bytes,
    _complex,
    _device,
    _dispatchkey,
    _dtype,
    _float,
    _int,
    _layout,
    _qscheme,
    _size,
    _str,
    _symsize,
)
from torch.utils._python_dispatch import TorchDispatchMode

from . import (
    _aoti,
    _cpu,
    _dynamo,
    _functorch,
    _lazy,
    _lazy_ts_backend,
    _nn,
    _onnx,
    _VariableFunctions,
    _verbose,
)

# This module is defined in torch/csrc/Module.cpp
79

80
K = TypeVar("K")
81
T = TypeVar("T")
82
S = TypeVar("S", bound="torch.Tensor")
83
P = ParamSpec("P")
84
ReturnVal = TypeVar("ReturnVal", covariant=True)  # return value (always covariant)
85
_T_co = TypeVar("_T_co", covariant=True)
86

87

88
@runtime_checkable
class _NestedSequence(Protocol[_T_co]):
    """A protocol for representing nested sequences.

    References::
        `numpy._typing._NestedSequence`
        <https://github.com/numpy/numpy/blob/main/numpy/_typing/_nested_sequence.py>
    """

    def __len__(self, /) -> builtins.int: ...
    def __getitem__(self, index: builtins.int, /) -> _T_co | _NestedSequence[_T_co]: ...
    def __contains__(self, x: builtins.object, /) -> builtins.bool: ...
    def __iter__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
    def __reversed__(self, /) -> Iterator[_T_co | _NestedSequence[_T_co]]: ...
    def count(self, value: Any, /) -> builtins.int: ...
    def index(self, value: Any, /) -> builtins.int: ...


# Defined in torch/csrc/Device.cpp
class device:
    type: str  # THPDevice_type
    index: _int  # THPDevice_index

    def __get__(self, instance, owner=None) -> device: ...

    # THPDevice_pynew
    @overload
    def __init__(self, device: DeviceLikeType) -> None: ...
    @overload
    def __init__(self, type: str, index: _int) -> None: ...

    # Uncomment if we ever make torch.device a decorator
    # def __call__(self, func: T) -> T: ...

    def __enter__(self) -> device: ...
    def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
    def __reduce__(self) -> Tuple[Any, ...]: ...  # THPDevice_reduce

# Defined in torch/csrc/Stream.cpp
class Stream:
    stream_id: _int  # Stream id
    device_index: _int
    device_type: _int

    device: _device  # The device of the stream

    # `__new__` receives the class, not an instance, as its first argument.
    @overload
    def __new__(cls, device: Optional[DeviceLikeType] = None, *, priority: _int = 0) -> Stream: ...
    @overload
    def __new__(cls, stream_id: _int, device_index: _int, device_type: _int, *, priority: _int = 0) -> Stream: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def wait_event(self, event: Event) -> None: ...
    def wait_stream(self, other: Stream) -> None: ...
    def record_event(self, event: Optional[Event] = None) -> Event: ...
    def __hash__(self) -> _int: ...
    def __repr__(self) -> str: ...
    def __eq__(self, other: object) -> _bool: ...


# Defined in torch/csrc/Event.cpp
class Event:
    device: _device  # The device of the Event
    event_id: _int  # The raw event created by device backend

    # `__new__` and classmethods receive the class as their first argument.
    def __new__(
        cls,
        device: Optional[DeviceLikeType] = None,
        *,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False,
    ) -> Event: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> Event: ...
    def record(self, stream: Optional[Stream] = None) -> None: ...
    def wait(self, stream: Optional[Stream] = None) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: Event) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...
    def __repr__(self) -> str: ...


# Defined in torch/csrc/Size.cpp
class Size(Tuple[_int, ...]):
    # TODO: __reduce__

    @overload  # type: ignore[override]
    def __getitem__(self: Size, key: _int) -> _int: ...
    @overload
    def __getitem__(self: Size, key: slice) -> Size: ...
    def numel(self: Size) -> _int: ...

# Defined in torch/csrc/Dtype.cpp
class dtype:
    # TODO: __reduce__
    is_floating_point: _bool
    is_complex: _bool
    is_signed: _bool
    itemsize: _int
    def to_real(self) -> dtype: ...
    def to_complex(self) -> dtype: ...

# Defined in torch/csrc/TypeInfo.cpp
class iinfo:
    bits: _int
    min: _int
    max: _int
    dtype: str

    def __init__(self, dtype: _dtype) -> None: ...

class finfo:
    bits: _int
    min: _float
    max: _float
    eps: _float
    tiny: _float
    smallest_normal: _float
    resolution: _float
    dtype: str

    @overload
    def __init__(self, dtype: _dtype) -> None: ...
    @overload
    def __init__(self) -> None: ...

${dtype_class_hints}

# Defined in torch/csrc/Layout.cpp
class layout: ...

# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...
def DisableTorchFunctionSubclass(): ...

# Defined in torch/csrc/utils/tensor_layouts.cpp
strided: layout = ...
sparse_coo: layout = ...
sparse_csr: layout = ...
sparse_csc: layout = ...
sparse_bsr: layout = ...
sparse_bsc: layout = ...
_mkldnn: layout = ...
jagged: layout = ...

# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...

# Defined in torch/csrc/utils/tensor_memoryformats.cpp
contiguous_format: memory_format = ...
channels_last: memory_format = ...
channels_last_3d: memory_format = ...
preserve_format: memory_format = ...

# Defined in torch/csrc/QScheme.cpp
class qscheme: ...

# Defined in torch/csrc/utils/tensor_qschemes.h
per_tensor_affine: qscheme = ...
per_channel_affine: qscheme = ...
per_tensor_symmetric: qscheme = ...
per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...

# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase:
    # These tuples are variable-length, so they are spelled `Tuple[X, ...]`;
    # `Tuple[X]` would declare a tuple of exactly one element.
    saved_tensors: Tuple[Tensor, ...]
    _raw_saved_tensors: Tuple[Any, ...]
    next_functions: Tuple[Tuple[Any, _int], ...]
    needs_input_grad: Tuple[_bool, ...]
    metadata: dict
    _materialize_non_diff_grads: _bool
    # skip adding type hints for the fields that have wrappers defined
    # in torch/autograd/function.py

# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(Tensor):  # inherits from Tensor to appease mypy
    def __init__(
        self,
        data: Optional[Tensor] = ...,
        requires_grad: Optional[_bool] = ...,
        volatile: Optional[_bool] = ...,
        _grad_fn: Optional[_FunctionBase] = ...,
    ) -> None: ...

# Defined in torch/csrc/jit/python/init.cpp
class IODescriptor: ...
class JITException: ...

class Future(Generic[T]):
    def __init__(self, devices: List[device]) -> None: ...
    def done(self) -> _bool: ...
    def value(self) -> T: ...
    def wait(self) -> T: ...
    def add_done_callback(self, callback: Callable) -> None: ...
    def then(self, callback: Callable) -> Future[T]: ...
    def set_result(self, result: T) -> None: ...
    def _set_unwrap_func(self, callback: Callable) -> None: ...

class _Await:
    def __init__(self) -> None: ...
    def fn(self) -> Callable: ...
    def args(self) -> Tuple[Any, ...]: ...
    def is_nowait(self) -> _bool: ...

def _jit_set_num_profiled_runs(num: _size) -> _size: ...

# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
class _MobileOptimizerType: ...

CONV_BN_FUSION: _MobileOptimizerType
INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType
REMOVE_DROPOUT: _MobileOptimizerType
FUSE_ADD_RELU: _MobileOptimizerType
HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType
VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType

def fork(*args: Any, **kwargs: Any) -> Future: ...
def wait(fut: Future) -> Any: ...
def _awaitable(*args: Any, **kwargs: Any) -> _Await: ...
def _awaitable_wait(aw: _Await) -> Any: ...
def _awaitable_nowait(x: Any) -> _Await: ...
def _collect_all(futures: List[Future]) -> Future: ...
def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
def unify_type_list(types: List[JitType]) -> JitType: ...
# Module freezing and frozen-graph passes.
def _freeze_module(
    module: ScriptModule,
    preserved_attrs: List[str] = [],
    freeze_interfaces: _bool = True,
    preserveParameters: _bool = True,
) -> ScriptModule: ...
def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
def _jit_pass_optimize_for_inference(
    module: torch.jit.ScriptModule,
    other_methods: List[str] = [],
) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_concat_frozen_linear(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_transpose_frozen_linear(graph: Graph): ...
def _jit_pass_remove_dropout(module: torch.jit.ScriptModule): ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
def _get_operation_overload(
    op_name: str,
    op_overload_name: str,
) -> Tuple[Callable, Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    optimization_blocklist: Set[_MobileOptimizerType],
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _clone_module_with_class(
    module: torch.jit.ScriptModule,
    ignored_methods: List[AnyStr],
    ignored_attributes: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_vulkan_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    optimization_blocklist: Set[_MobileOptimizerType],
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_metal_optimize_for_mobile(
    module: torch.jit.ScriptModule,
    preserved_methods: List[AnyStr],
) -> torch.jit.ScriptModule: ...
def _jit_pass_inline(Graph) -> None: ...
def _jit_pass_constant_propagation(Graph) -> None: ...
def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ...
def _jit_erase_non_input_shape_information(Graph) -> None: ...
def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ...
def _jit_get_all_schemas() -> List[FunctionSchema]: ...
def _jit_check_alias_annotation(
    g: Graph,
    args: Tuple[Any, ...],
    unqualified_op_name: str,
): ...
# Fuser capability queries and on/off switches.
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
def _jit_texpr_fuser_enabled() -> _bool: ...
def _jit_nvfuser_enabled() -> _bool: ...
def _jit_llga_enabled() -> _bool: ...
def _jit_set_llga_enabled(enable: _bool): ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
def _jit_set_te_must_use_llvm_cpu(use_llvm: _bool): ...
def _jit_set_nvfuser_enabled(enable: _bool) -> _bool: ...
def _jit_cat_wo_conditionals(optimize_cat: _bool): ...
def _jit_opt_conditionals(opt_conds: _bool): ...
def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ...
def _jit_pass_erase_shape_information(graph: Graph): ...
def _jit_pass_fold_convbn(module: torch.jit.ScriptModule): ...
# Quantization passes.
def _jit_pass_insert_observers(
    module: torch.jit.ScriptModule,
    method_name: str,
    qconfig_dict: Dict[str, Any],
    inplace: _bool,
    quant_type: _int,
): ...
def _jit_pass_insert_quant_dequant(
    module: torch.jit.ScriptModule,
    method_name: str,
    inplace: _bool,
    debug: _bool,
    quant_type: _int,
): ...
def _jit_pass_insert_quant_dequant_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    method_name: str,
    inplace: _bool,
    debug: _bool,
    quant_type: _int,
): ...
def _jit_pass_quant_finalize(
    module: torch.jit.ScriptModule,
    quant_type: _int,
    preserved_attrs: Sequence[str],
): ...
def _jit_pass_quant_finalize_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    quant_type: _int,
    method_name: str,
): ...
def _jit_pass_insert_observer_method_for_ondevice_ptq(
    module: torch.jit.ScriptModule,
    method_name: str,
    qconfig_dict: Dict[str, Any],
    inplace: _bool,
    quant_type: _int,
): ...
# Profiling executor / fusion strategy controls.
def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ...
def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ...
def _jit_set_fusion_strategy(
    strategy: List[Tuple[str, _int]],
) -> List[Tuple[str, _int]]: ...
def _jit_try_infer_type(obj: Any) -> InferredType: ...
def _jit_get_trigger_value(trigger_name: str) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Maps a name (str) to a callable.
ResolutionCallback = Callable[[str], Callable[..., Any]]

# Defined in torch/csrc/jit/python/script_init.cpp
#        and torch/csrc/jit/python/init.cpp
def _maybe_call_torch_function_for_op_packet(
    op_overload_packet: Any,
    args: Any,
    kwargs: Any,
) -> Any: ...
def _check_schema_allow_fake_script_object(
    schema: FunctionSchema,
    args: Any,
    kwargs: Any,
) -> _bool: ...
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
def _jit_assert_is_instance(obj: Any, type: JitType): ...
def _jit_clear_class_registry() -> None: ...
def _jit_set_emit_hooks(
    ModuleHook: Optional[Callable],
    FunctionHook: Optional[Callable],
) -> None: ...
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
# Lite-interpreter / mobile model helpers.
def _load_for_lite_interpreter(
    filename: Union[str, Path],
    map_location: Optional[DeviceLikeType],
): ...
def _load_for_lite_interpreter_from_buffer(
    buffer: BinaryIO,
    map_location: Optional[DeviceLikeType],
): ...
def _export_operator_list(module: LiteScriptModule): ...
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
def _backport_for_mobile(
    filename_input: Union[str, Path],
    filename_output: Union[str, Path],
    to_version: _int,
) -> None: ...
def _backport_for_mobile_from_buffer(
    buffer: BinaryIO,
    filename_output: Union[str, Path],
    to_version: _int,
) -> None: ...
def _backport_for_mobile_to_buffer(
    filename_input: Union[str, Path],
    to_version: _int,
) -> bytes: ...
def _backport_for_mobile_from_buffer_to_buffer(
    buffer: BinaryIO,
    to_version: _int,
) -> bytes: ...
def _get_model_ops_and_info(filename: Union[str, Path]): ...
def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
def _get_mobile_model_contained_types_from_buffer(buffer: BinaryIO): ...
def _logging_set_logger(logger: LoggerBase) -> LoggerBase: ...
def _get_graph_executor_optimize(optimize: Optional[_bool] = None) -> _bool: ...
def _set_graph_executor_optimize(optimize: _bool): ...
def _export_opnames(module: ScriptModule) -> List[str]: ...
def _create_function_from_trace(
    qualname: str,
    func: Callable[..., Any],
    input_tuple: Tuple[Any, ...],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
    qualname: str,
    func: Callable[..., Any],
    input_dict: Dict[str, Any],
    var_lookup_fn: Callable[[Tensor], str],
    strict: _bool,
    force_outplace: _bool,
    argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
def _get_upgraders_map_size() -> _int: ...
def _get_upgraders_entry_map() -> Dict[str, str]: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...
def merge_type_from_type_comment(
    decl: Decl,
    type_annotation_decl: Decl,
    is_method: _bool,
) -> Decl: ...
def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ...
def parse_schema(schema: str) -> FunctionSchema: ...
def get_device(input: Tensor) -> _int: ...
def _resolve_type_from_object(
    obj: Any,
    range: SourceRange,
    rcb: ResolutionCallback,
) -> JitType: ...
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _create_object_with_type(ty: ClassType) -> ScriptObject: ...
def _run_emit_module_hook(m: ScriptModule): ...
def _replace_overloaded_method_decl(
    overload_decl: Decl,
    implementation_def: Def,
    new_name: str,
) -> Def: ...
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
# ONNX export passes.
def _jit_pass_onnx_set_dynamic_input_shape(
    graph: Graph,
    dynamic_axes: Dict[str, Dict[_int, str]],
    input_names: List[str],
) -> None: ...
def _jit_pass_onnx_graph_shape_type_inference(
    graph: Graph,
    params_dict: Dict[str, IValue],
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_assign_output_shape(
    graph: Graph,
    tensors: List[Tensor],
    desc: IODescriptor,
    onnx_shape_inference: _bool,
    is_script: _bool,
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_remove_inplace_ops_for_onnx(
    graph: Graph,
    module: Optional[ScriptModule] = None,
) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
def _jit_pass_peephole(
    graph: Graph,
    disable_shape_peepholes: _bool = False,
) -> None: ...
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_custom_pattern_based_rewrite_graph(
    pattern: str,
    fused_node_name: str,
    graph: Graph,
) -> None: ...
def _jit_onnx_list_model_parameters(
    module: ScriptModule,
) -> Tuple[ScriptModule, List[IValue]]: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx_lint(graph: Graph) -> None: ...
# The second parameter was previously named `_jit_pass_onnx` (the function's
# own name); it is the export type, named as in `_jit_pass_onnx_block`.
def _jit_pass_onnx(
    graph: Graph,
    operator_export_type: _onnx.OperatorExportTypes,
) -> Graph: ...
def _jit_pass_onnx_scalar_type_analysis(
    graph: Graph,
    lowprecision_cast: _bool,
    opset_version: _int,
) -> None: ...
def _jit_pass_onnx_peephole(
    graph: Graph,
    opset_version: _int,
    fixed_batch_size: _bool,
) -> None: ...
def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
def _jit_pass_onnx_function_extraction(
    graph: Graph,
    module_names: Set[str],
    param_names: List[str],
) -> Dict[Node, Dict[str, str]]: ...
def _jit_pass_onnx_clear_scope_records() -> None: ...
def _jit_pass_onnx_track_scope_attributes(
    graph: Graph,
    onnx_attrs: Dict[str, Any],
) -> None: ...
def _jit_is_onnx_log_enabled() -> _bool: ...
def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ...
def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ...
def _jit_onnx_log(*args: Any) -> None: ...
def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
def _jit_pass_onnx_deduplicate_initializers(
    graph: Graph,
    params_dict: Dict[str, IValue],
    is_train: _bool,
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eval_peephole(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_constant_fold(
    graph: Graph,
    paramsDict: Dict[str, IValue],
    opset_version: _int,
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_eliminate_unused_items(
    graph: Graph,
    paramsDict: Dict[str, IValue],
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
def _jit_pass_filter_non_tensor_arguments(
    params: Dict[str, IValue],
) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
def _jit_pass_onnx_node_shape_type_inference(
    n: Node,
    paramsDict: Dict[str, IValue],
    opset_version: _int,
) -> None: ...
def _jit_onnx_convert_pattern_from_subblock(
    block: Block,
    n: Node,
    env: Dict[Value, Value],
    values_in_env: Set[Value],
) -> List[Value]: ...
def _jit_pass_onnx_block(
    old_block: Block,
    new_block: Block,
    operator_export_type: _onnx.OperatorExportTypes,
    env: Dict[Value, Value],
    values_in_env: Set[Value],
    is_sub_block: _bool,
) -> Dict[Value, Value]: ...
def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ...
def _jit_pass_fixup_onnx_controlflow_node(
    n: Node,
    opset_version: _int,
) -> List[Value]: ...
def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ...
def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
def _generate_upgraders_graph() -> Dict[str, Graph]: ...
def _calculate_package_version_based_on_upgraders(val: _bool): ...
def _get_version_calculator_flag() -> _bool: ...
# TorchScript compilation entry points and IR module import.
def _jit_script_interface_compile(
    name: str,
    class_def: ClassDef,
    rcb: ResolutionCallback,
    is_module: _bool,
): ...
def _jit_script_compile_overload(
    qualname: str,
    overload_decl: Decl,
    implementation_def: Def,
    rcb: ResolutionCallback,
    implementation_defaults: Dict[str, Any],
    signature: Any,
): ...
def _jit_script_compile(
    qual_name: str,
    definition: Def,
    rcb: ResolutionCallback,
    defaults: Dict[str, Any],
): ...
def _jit_script_class_compile(
    qual_name: str,
    definition: ClassDef,
    defaults: Dict[str, Dict[str, Any]],
    rcb: ResolutionCallback,
): ...
def _parse_source_def(src: str) -> Def: ...
def import_ir_module(
    cu: CompilationUnit,
    filename: Union[str, Path],
    map_location: Optional[DeviceLikeType],
    extra_files: Dict[str, Any],
) -> ScriptModule: ...
def import_ir_module_from_buffer(
    cu: CompilationUnit,
    buffer: BinaryIO,
    map_location: Optional[DeviceLikeType],
    extra_files: Dict[str, Any],
) -> ScriptModule: ...
def _import_ir_module_from_package(
    cu: CompilationUnit,
    reader: PyTorchFileReader,
    storage_context: DeserializationStorageContext,
    map_location: Optional[DeviceLikeType],
    ts_id: str,
) -> ScriptModule: ...
def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str) -> None: ...
def _propagate_and_assign_input_shapes(
    graph: Graph,
    inputs: Tuple[Tensor, ...],
    param_count_list: List[_int],
    with_grad: _bool,
    propagate: _bool,
) -> Graph: ...

# Defined in torch/csrc/jit/runtime/graph_executor.h
class GraphExecutorState: ...

# Defined in torch/csrc/jit/ir/alias_analysis.h
class AliasDb:
    def __str__(self) -> str: ...

# Context-manager type returned by Graph.insert_point_guard.
class _InsertPoint:
    def __enter__(self) -> None: ...
    def __exit__(self, *args) -> None: ...

# Defined in torch/csrc/jit/ir/ir.h
class Use:
    @property
    def user(self) -> Node: ...
    @property
    def offset(self) -> _int: ...
    def isAfter(self, other: Use) -> _bool: ...

# Defined in torch/csrc/jit/ir/ir.h
class Value:
    def type(self) -> JitType: ...
    def setType(self, t: JitType) -> Value: ...
    def setTypeAs(self, other: Value) -> Value: ...
    def inferTypeFrom(self, t: Tensor) -> None: ...
    def debugName(self) -> str: ...
    def setDebugName(self, name: str) -> None: ...
    def unique(self) -> _int: ...
    def offset(self) -> _int: ...
    def node(self) -> Node: ...
    def uses(self) -> List[Use]: ...
    def replaceAllUsesWith(self, val: Value) -> None: ...
    def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
    def requires_grad(self) -> _bool: ...
    def requiresGrad(self) -> _bool: ...
    def copyMetadata(self, other: Value) -> Value: ...
    def isCompleteTensor(self) -> _bool: ...
    def toIValue(self) -> IValue: ...

# Defined in torch/csrc/jit/ir/ir.h
class Block:
    def inputs(self) -> Iterator[Value]: ...
    def outputs(self) -> Iterator[Value]: ...
    def nodes(self) -> Iterator[Node]: ...
    def paramNode(self) -> Node: ...
    def returnNode(self) -> Node: ...
    def owningNode(self) -> Node: ...
    def registerOutput(self, n: Value) -> _int: ...
    def addNode(self, name: str, inputs: Sequence[Value]) -> Node: ...

# Defined in torch/csrc/jit/ir/ir.h
class Node:
    def __getitem__(self, key: str) -> Any: ...
    def schema(self) -> str: ...
    def input(self) -> Value: ...
    def inputs(self) -> Iterator[Value]: ...
    def inputsAt(self, idx: _int) -> Value: ...
    def inputsSize(self) -> _int: ...
    def output(self) -> Value: ...
    def outputs(self) -> Iterator[Value]: ...
    def outputsAt(self, idx: _int) -> Value: ...
    def outputsSize(self) -> _int: ...
    def hasMultipleOutputs(self) -> _bool: ...
    def blocks(self) -> List[Block]: ...
    def addBlock(self) -> Block: ...
    def mustBeNone(self) -> _bool: ...
    def matches(self, pattern: str) -> _bool: ...
    def kind(self) -> str: ...
    def kindOf(self, name: str) -> str: ...
    def addInput(self, name: str) -> Value: ...
    def replaceInput(self, i: _int, newValue: Value) -> Value: ...
    def replaceInputWith(self, from_: Value, to: Value) -> None: ...
    def replaceAllUsesWith(self, n: Node) -> None: ...
    def insertBefore(self, n: Node) -> Node: ...
    def insertAfter(self, n: Node) -> Node: ...
    def isBefore(self, n: Node) -> _bool: ...
    def isAfter(self, n: Node) -> _bool: ...
    def moveBefore(self, n: Node) -> None: ...
    def moveAfter(self, n: Node) -> None: ...
    def removeInput(self, i: _int) -> None: ...
    # Unlike removeInput, this takes no index argument.
    def removeAllInputs(self) -> None: ...
    def hasUses(self) -> _bool: ...
    def eraseOutput(self, i: _int) -> None: ...
    def addOutput(self) -> Value: ...
    def scopeName(self) -> str: ...
    def isNondeterministic(self) -> _bool: ...
    def copyAttributes(self, rhs: Node) -> Node: ...
    def copyMetadata(self, rhs: Node) -> Node: ...
    def hasAttributes(self) -> _bool: ...
    def hasAttribute(self, name: str) -> _bool: ...
    def removeAttribute(self, attr: str) -> Node: ...
    def namedInput(self, name: str) -> Value: ...
    def sourceRange(self) -> SourceRange: ...
    def owningBlock(self) -> Block: ...
    def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
    def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
    def getModuleHierarchy(self) -> str: ...
    def prev(self) -> Node: ...
    def destroy(self) -> None: ...
    def attributeNames(self) -> List[str]: ...

    # Accessors for attributes as types.
    def f(self, name: str) -> _float: ...
    def f_(self, name: str, val: _float) -> Node: ...
    def fs(self, name: str) -> List[_float]: ...
    def fs_(self, name: str, val: List[_float]) -> Node: ...
    def c(self, name: str) -> complex: ...
    def c_(self, name: str, val: complex) -> Node: ...
    def s(self, name: str) -> str: ...
    def s_(self, name: str, val: str) -> Node: ...
    def ss(self, name: str) -> List[str]: ...
    def ss_(self, name: str, val: List[str]) -> Node: ...
    def i(self, name: str) -> _int: ...
    def i_(self, name: str, val: _int) -> Node: ...
    # Cannot define "is" like this because it's a reserved keyword in python.
    # def is(self, name: str) -> List[_int]: ...
    # def is_(self, name: str, val: List[_int]) -> Node: ...
    def g(self, name: str) -> Graph: ...
    def g_(self, name: str, val: Graph) -> Node: ...
    def gs(self, name: str) -> List[Graph]: ...
    def gs_(self, name: str, val: List[Graph]) -> Node: ...
    def ival(self, name: str) -> IValue: ...
    def ival_(self, name: str, val: IValue) -> Node: ...
    def t(self, name: str) -> Tensor: ...
    def t_(self, name: str, val: Tensor) -> Node: ...
    def ts(self, name: str) -> List[Tensor]: ...
    def ts_(self, name: str, val: List[Tensor]) -> Node: ...
    def ty(self, name: str) -> JitType: ...
    def ty_(self, name: str, val: JitType) -> Node: ...
    def tys(self, name: str) -> List[JitType]: ...
    def tys_(self, name: str, val: List[JitType]) -> Node: ...

# Defined in torch/torch/csrc/jit/ir/ir.h
879
class Graph:
880
    def inputs(self) -> Iterator[Value]: ...
881
    def outputs(self) -> Iterator[Value]: ...
882
    def nodes(self) -> Iterator[Node]: ...
883
    def param_node(self) -> Node: ...
884
    def return_node(self) -> Node: ...
885
    def addInput(self, name: str = "") -> Value: ...
886
    def eraseInput(self, i: _int) -> None: ...
887
    def registerOutput(self, n: Value) -> _int: ...
888
    def eraseOutput(self, i: _int) -> None: ...
889
    def create(self, name: str, args, num_outputs: _int) -> Node: ...
890
    def appendNode(self, n: Node) -> Node: ...
891
    def prependNode(self, n: Node) -> Node: ...
892
    def insertNode(self, n: Node) -> Node: ...
893
    def block(self) -> Block: ...
894
    def lint(self) -> None: ...
895
    def alias_db(self) -> AliasDb: ...
896
    def setInsertPoint(self, n: Union[Block, Node]) -> None: ...
897
    def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ...
898
    def insertPoint(self) -> Node: ...
899
    def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ...
900
    def makeMultiOutputIntoTuple(self) -> None: ...
901
    def copy(self) -> Graph: ...
902

# Defined in torch/aten/src/ATen/core/alias_info.h
# Alias annotation carried by a schema Argument (see Argument.alias_info).
class AliasInfo:
    is_write: _bool
    before_set: Set[str]
    after_set: Set[str]

# Defined in torch/aten/src/ATen/core/function_schema.h
# One argument (or return value) entry of a FunctionSchema.
class Argument:
    name: str
    type: JitType
    default_value: Optional[Any]
    def has_default_value(self) -> _bool: ...
    kwarg_only: _bool
    is_out: _bool
    alias_info: Optional[AliasInfo]

# Operator schema: name and overload name plus argument and return entries.
class FunctionSchema:
    arguments: List[Argument]
    returns: List[Argument]
    name: str
    overload_name: str
    is_mutable: _bool

# One entry of the operator-version map: the version at which an operator was
# bumped, the upgrader's name and the schema before the bump.
class _UpgraderEntry:
    bumped_at_version: _int
    upgrader_name: str
    old_schema: str
    def __init__(
        self,
        bumped_at_version: _int,
        upgrader_name: str,
        old_schema: str,
    ) -> None: ...

# Version range reported by _get_upgrader_ranges.
class _UpgraderRange:
    min_version: _int
    max_version: _int

# Operator-versioning (upgrader) registry accessors.
def _get_max_operator_version() -> _int: ...
def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...
def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ...
# Test-only hooks (per their names) for adding/removing op-version entries.
def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Serializes ScriptModules through the PyTorchFileWriter it is constructed with.
class ScriptModuleSerializer:
    def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
    def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
    def write_files(self) -> None: ...
    def storage_context(self) -> SerializationStorageContext: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Tracks storages seen during serialization; get_or_add_storage returns an int id.
class SerializationStorageContext:
    def __init__(self) -> None: ...
    def has_storage(self, storage: Storage) -> _bool: ...
    def get_or_add_storage(self, storage: Storage) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Name-keyed storage lookup used during deserialization.
class DeserializationStorageContext:
    def __init__(self) -> None: ...
    def get_storage(self, name: str, dtype: _dtype) -> Tensor: ...
    def has_storage(self, name: str) -> _bool: ...
    def add_storage(self, name: str, tensor: Tensor) -> _int: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Incrementally records a module's attributes, submodules, constants, overloads
# and hooks (presumably consumed to produce a ConcreteModuleType -- confirm).
# NOTE(review): most methods carry no return annotation; confirm whether they
# return None and annotate accordingly.
class ConcreteModuleTypeBuilder:
    def __init__(self, obj: Any) -> None: ...
    def set_module_dict(self): ...
    def set_module_list(self): ...
    def set_parameter_list(self): ...
    def set_parameter_dict(self): ...
    def add_attribute(
        self,
        name: str,
        ty: JitType,
        is_param: _bool,
        is_buffer: _bool,
    ): ...
    def add_module(self, name: str, meta: ConcreteModuleType): ...
    def add_constant(self, name: str, value: Any): ...
    def add_overload(self, method_name: str, overloaded_method_names: List[str]): ...
    def add_builtin_function(self, name: str, symbol_name: str): ...
    def add_failed_attribute(self, name: str, failure_reason: str): ...
    def add_function_attribute(
        self,
        name: str,
        ty: JitType,
        func: Callable[..., Any],
    ): ...
    def add_ignored_attribute(self, name: str): ...
    def add_ignored_attributes(self, names: List[str]): ...
    def add_forward_hook(self, hook: Callable[..., Any]): ...
    def add_forward_pre_hook(self, pre_hook: Callable[..., Any]): ...

# Concrete module type: exposes its constants, supports equality checks and can
# be built from a JitType.
class ConcreteModuleType:
    def get_constants(self) -> Dict[str, Any]: ...
    def equals(self, other: ConcreteModuleType) -> _bool: ...
    @staticmethod
    def from_jit_type(ty: JitType) -> ConcreteModuleType: ...

# Call-stack entry constructed from a name and its SourceRange.
class CallStack:
    def __init__(self, name: str, range: SourceRange) -> None: ...

# Error report tied to a SourceRange; `what()` returns its text and the static
# `call_stack()` returns a string.
class ErrorReport:
    def __init__(self, range: SourceRange) -> None: ...
    def what(self) -> str: ...
    @staticmethod
    def call_stack() -> str: ...

# Container of compiled functions/classes: `define` compiles script source, and
# lookups (including attribute access) return ScriptFunctions.
class CompilationUnit:
    def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ...
    def find_function(self, name: str) -> ScriptFunction: ...
    def __getattr__(self, name: str) -> ScriptFunction: ...
    def define(
        self,
        script: str,
        rcb: ResolutionCallback = ...,
        _frames_up: _int = ...,
    ): ...
    def get_interface(self, name: str) -> InterfaceType: ...
    def get_functions(self) -> List[ScriptFunction]: ...
    def create_function(
        self,
        name: str,
        graph: Graph,
        shouldMangle: _bool = ...,
    ) -> ScriptFunction: ...
    def get_class(self, name: str) -> ClassType: ...

# Base stub for scripted objects; attributes are set by name via `setattr`.
class ScriptObject:
    def setattr(self, name: str, value: Any): ...

# Scripted module: lists its method names and looks methods up by name.
class ScriptModule(ScriptObject):
    def _method_names(self) -> List[str]: ...
    def _get_method(self, name: str) -> ScriptMethod: ...

# Module loaded for the lite interpreter; methods can be run by name.
class LiteScriptModule:
    def __call__(self, *input): ...
    def find_method(self, method_name: str): ...
    # NOTE(review): `forward` is annotated as returning List[str], which looks
    # unusual for a module's forward -- confirm against the binding.
    def forward(self, *input) -> List[str]: ...
    def run_method(self, method_name: str, *input): ...

# NOTE: switch to collections.abc.Callable in python 3.9
# Callable stub for a scripted function; P / ReturnVal type its call signature.
class ScriptFunction(Generic[P, ReturnVal]):
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
    def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
    def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
    @property
    def graph(self) -> Graph: ...
    # NOTE(review): `graph` and `qualified_name` are properties, but the four
    # accessors below are declared as methods (ScriptMethod.name is a
    # property) -- confirm against the binding registration.
    def inlined_graph(self) -> Graph: ...
    def schema(self) -> FunctionSchema: ...
    def code(self) -> str: ...
    def name(self) -> str: ...
    @property
    def qualified_name(self) -> str: ...

# NOTE: switch to collections.abc.Callable in python 3.9
# Callable stub for a method bound to a ScriptModule (exposed as `owner`).
class ScriptMethod(Generic[P, ReturnVal]):
    graph: Graph
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> ReturnVal: ...
    @property
    def owner(self) -> ScriptModule: ...
    @property
    def name(self) -> str: ...

# Dict-like stub (mapping protocol plus items/keys iterators) over K -> T.
class ScriptDict(Generic[K, T]):
    def __init__(self, dict: Dict[K, T]) -> None: ...
    def __len__(self) -> _int: ...
    def __contains__(self, key: K) -> _bool: ...
    def __getitem__(self, key: K) -> T: ...
    def __setitem__(self, key: K, value: T) -> None: ...
    def __delitem__(self, key: K) -> None: ...
    def __iter__(self) -> Iterator[K]: ...
    def items(self) -> Iterator[Tuple[K, T]]: ...
    def keys(self) -> Iterator[K]: ...

# List-like stub over T; indexing and assignment accept an int or a slice.
class ScriptList(Generic[T]):
    def __init__(self, list: List[T]) -> None: ...
    def __len__(self) -> _int: ...
    def __contains__(self, item: T) -> _bool: ...
    @overload
    def __getitem__(self, idx: _int) -> T: ...
    @overload
    def __getitem__(self, idx: slice) -> ScriptList[T]: ...
    @overload
    def __setitem__(self, idx: _int, value: T) -> None: ...
    @overload
    def __setitem__(self, idx: slice, value: List[T]) -> None: ...
    def __delitem__(self, idx: _int) -> None: ...
    def __iter__(self) -> Iterator[T]: ...
    def count(self, value: T) -> _int: ...
    def remove(self, value: T) -> None: ...
    def append(self, value: T) -> None: ...
    def clear(self) -> None: ...
    @overload
    def extend(self, values: List[T]) -> None: ...
    @overload
    def extend(self, values: Iterable[T]) -> None: ...
    @overload
    def pop(self) -> T: ...
    @overload
    def pop(self, idx: _int) -> T: ...

# Name -> submodule view over a ScriptModule.
class ModuleDict:
    def __init__(self, mod: ScriptModule) -> None: ...
    def items(self) -> List[Tuple[str, Any]]: ...

# Parameter view constructed over a ScriptModule.
class ParameterDict:
    def __init__(self, mod: ScriptModule) -> None: ...

# Buffer view constructed over a ScriptModule.
class BufferDict:
    def __init__(self, mod: ScriptModule) -> None: ...

# Defined in torch/csrc/jit/api/module.h
# Opaque stub; no members are declared.
class Module: ...

# Defined in torch/csrc/Module.cpp
# The trailing comment on each stub names its C++ implementation.
def _initExtension(shm_manager_path: str) -> None: ...  # THPModule_initExtension
def _autograd_init() -> _bool: ...  # THPAutograd_initExtension
def _add_docstr(obj: T, doc_obj: str) -> T: ...  # THPModule_addDocStr
def _init_names(arg: Sequence[_Type]) -> None: ...  # THPModule_initNames
def _has_distributed() -> _bool: ...  # THPModule_hasDistributed
# NOTE(review): `type` is unannotated (and shadows the builtin name); the
# accepted argument types are not visible here.
def _set_default_tensor_type(type) -> None: ...  # THPModule_setDefaultTensorType
def _set_default_dtype(d: _dtype) -> None: ...  # THPModule_setDefaultDtype
def _infer_size(arg1: Size, arg2: Size) -> Size: ...  # THPModule_inferSize
def _crash_if_csrc_asan() -> _int: ...  # THPModule_crashIfCsrcASAN
def _crash_if_csrc_ubsan() -> _int: ...  # THPModule_crashIfCsrcUBSAN
def _crash_if_aten_asan() -> _int: ...  # THPModule_crashIfATenASAN
def _show_config() -> str: ...  # THPModule_showConfig
def _cxx_flags() -> str: ...  # THPModule_cxxFlags
def _parallel_info() -> str: ...  # THPModule_parallelInfo
def _get_cpu_capability() -> str: ...  # THPModule_getCpuCapability
def _set_backcompat_broadcast_warn(
    arg: _bool,
) -> None: ...  # THPModule_setBackcompatBroadcastWarn
def _get_backcompat_broadcast_warn() -> _bool: ...  # THPModule_getBackcompatBroadcastWarn
def _set_backcompat_keepdim_warn(
    arg: _bool,
) -> None: ...  # THPModule_setBackcompatKeepdimWarn
def _get_backcompat_keepdim_warn() -> _bool: ...  # THPModule_getBackcompatKeepdimWarn
# Thread-count getters/setters.
def get_num_threads() -> _int: ...  # THPModule_getNumThreads
# `get_num_thread` (no trailing "s") breaks the naming used by set_num_threads
# and get_num_interop_threads; the correctly spelled stub is declared above and
# this spelling is retained only so existing type-checked callers still resolve.
# NOTE(review): confirm no runtime binding uses this name, then drop it.
def get_num_thread() -> _int: ...  # THPModule_getNumThreads
def set_num_threads(nthreads: _int) -> None: ...  # THPModule_setNumThreads
def get_num_interop_threads() -> _int: ...  # THPModule_getNumInteropThreads
def set_num_interop_threads(
    nthreads: _int,
) -> None: ...  # THPModule_setNumInteropThreads
def _get_cudnn_enabled() -> _bool: ...  # THPModule_userEnabledCuDNN
def _set_cudnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledCuDNN
# SDP backend toggles: one getter/setter pair per backend (flash,
# memory-efficient, math, overrideable, cuDNN). The trailing comments name the
# C++ implementation of each binding.
def _get_flash_sdp_enabled() -> _bool: ...  # THPModule_userEnabledFlashSDP
def _set_sdp_use_flash(arg: _bool) -> None: ...  # THPModule_setSDPUseFlash
def _get_mem_efficient_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMemEfficientSDP
def _set_sdp_use_mem_efficient(
    arg: _bool,
) -> None: ...  # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ...  # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ...  # THPModule_setSDPUseMath
def _get_overrideable_sdp_enabled() -> _bool: ...  # THPModule_userEnabledOverrideableSDP
def _set_sdp_use_overrideable(arg: _bool) -> None: ...  # THPModule_setSDPUseOverrideable
def _get_cudnn_sdp_enabled() -> _bool: ...  # THPModule_userEnabledCuDNNSDP
def _set_sdp_use_cudnn(arg: _bool) -> None: ...  # THPModule_setSDPUseCuDNN
# mkldnn / cuDNN enablement, benchmarking and determinism toggles.
def _get_mkldnn_enabled() -> _bool: ...  # THPModule_userEnabledMkldnn
def _set_mkldnn_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledMkldnn
def _get_cudnn_benchmark() -> _bool: ...  # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ...  # THPModule_setBenchmarkCuDNN
def _get_cudnn_deterministic() -> _bool: ...  # THPModule_deterministicCuDNN
def _set_cudnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicCuDNN
def _get_mkldnn_deterministic() -> _bool: ...  # THPModule_deterministicMkldnn
def _set_mkldnn_deterministic(arg: _bool) -> None: ...  # THPModule_setDeterministicMkldnn
def _get_deterministic_algorithms() -> _bool: ...  # THPModule_deterministicAlgorithms
def _get_deterministic_algorithms_warn_only() -> _bool: ...  # THPModule_deterministicAlgorithmsWarnOnly
# `warn_only` is keyword-only.
def _set_deterministic_algorithms(
    mode: _bool,
    *,
    warn_only: _bool = ...,
) -> None: ...  # THPModule_setDeterministicAlgorithms
def _get_deterministic_fill_uninitialized_memory() -> _bool: ...  # THPModule_deterministicFillUninitializedMemory
def _set_deterministic_fill_uninitialized_memory(arg: _bool) -> None: ...  # THPModule_setDeterministicFillUninitializedMemory
# NNPACK, warn-always, TF32, float32-matmul-precision and reduced-precision
# reduction toggles.
def _get_nnpack_enabled() -> _bool: ...  # THPModule_userEnabledNNPACK
def _set_nnpack_enabled(arg: _bool) -> None: ...  # THPModule_setUserEnabledNNPACK
def _get_warnAlways() -> _bool: ...  # THPModule_warnAlways
def _set_warnAlways(arg: _bool) -> None: ...  # THPModule_setWarnAlways
def _get_cudnn_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ...  # THPModule_allowTF32CuBLAS
def _set_cublas_allow_tf32(arg: _bool) -> None: ...  # THPModule_setAllowTF32CuBLAS
def _get_float32_matmul_precision() -> str: ...  # THPModule_float32MatmulPrecision
def _set_float32_matmul_precision(
    arg: str,
) -> None: ...  # THPModule_setFloat32MatmulPrecision
def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowFP16ReductionCuBLAS
def _set_cublas_allow_fp16_reduced_precision_reduction(
    arg: _bool,
) -> None: ...  # THPModule_setAllowFP16ReductionCuBLAS
def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ...  # THPModule_allowBF16ReductionCuBLAS
def _set_cublas_allow_bf16_reduced_precision_reduction(
    arg: _bool,
) -> None: ...  # THPModule_setAllowBF16ReductionCuBLAS
# Tensor flag setters, thread-local (TLS) state helpers, backend selection and
# storage helpers.
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
def _meta_in_tls_dispatch_include() -> _bool: ...
def _stash_obj_in_tls(key: str, arg: Any) -> None: ...
def _get_obj_in_tls(key: str) -> Any: ...
def _is_key_in_tls(key: str) -> _bool: ...
def _select_batch_norm_backend(*args, **kwargs) -> BatchNormBackend: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
def _conv_determine_backend_memory_format(
    input: Tensor,
    weight: Tensor,
    backend: ConvBackend,
) -> memory_format: ...
def _has_storage(x: Tensor) -> _bool: ...
def _construct_storage_from_data_pointer(
    data_ptr: _int,
    device: torch.device,
    size: _int,
) -> Storage: ...
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
def _group_tensors_by_device_and_dtype(
    nested_tensorlists: List[List[Optional[Tensor]]],
    with_indices: _bool = False,
) -> Dict[
    Tuple[torch.device, torch.dtype],
    Tuple[List[List[Optional[Tensor]]], List[_int]],
]: ...

# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
# so the DLPack capsule is typed as Any below.
def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack
def _from_dlpack(data: Any) -> Tensor: ...  # THPModule_fromDLPack
def _get_cpp_backtrace(
    frames_to_skip: _int,
    maximum_number_of_frames: _int,
) -> str: ...  # THPModule_getCppBacktrace
def set_flush_denormal(arg: _bool) -> _bool: ...  # THPModule_setFlushDenormal
def get_default_dtype() -> _dtype: ...  # THPModule_getDefaultDtype
def _get_default_device() -> str: ...  # THPModule_getDefaultDevice
def _get_qengine() -> _int: ...  # THPModule_qEngine
def _set_qengine(qengine: _int) -> None: ...  # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ...  # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ...  # THPModule_isEnabledXNNPACK
def _check_sparse_tensor_invariants() -> _bool: ...  # THPModule_checkSparseTensorInvariants
def _set_check_sparse_tensor_invariants(
    arg: _bool,
) -> None: ...  # THPModule_setCheckSparseTensorInvariants
def _set_default_mobile_cpu_allocator() -> None: ...  # THPModule_setDefaultMobileCPUAllocator
def _unset_default_mobile_cpu_allocator() -> None: ...  # THPModule_unsetDefaultMobileCPUAllocator
def _is_torch_function_enabled() -> _bool: ...  # THPModule_isEnabledTorchFunction
def _is_torch_function_all_disabled() -> _bool: ...  # THPModule_isAllDisabledTorchFunction
def _has_torch_function(
    args: Iterable[Any],
) -> _bool: ...  # THPModule_has_torch_function
# The single-argument stubs below take their parameter positional-only, so the
# stub does not promise a keyword name.
def _has_torch_function_unary(obj: Any, /) -> _bool: ...  # THPModule_has_torch_function_unary
def _has_torch_function_variadic(
    *args: Any,
) -> _bool: ...  # THPModule_has_torch_function_variadic
def _vmapmode_increment_nesting() -> _int: ...  # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ...  # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(event: str, /) -> None: ...  # LogAPIUsageOnceFromPython
def _log_api_usage_metadata(event: str, metadata_map: Dict[str, str]) -> None: ...  # LogAPIUsageMetadataFromPython
def _demangle(name: str, /) -> str: ...  # c10::demangle
def _disabled_torch_function_impl(
    func: Callable,
    types: Iterable[_Type],
    args: Tuple,
    kwargs: Dict,
) -> Any: ...  # THPModule_disable_torch_function
def _disabled_torch_dispatch_impl(
    func: Callable,
    types: Iterable[_Type],
    args: Tuple,
    kwargs: Dict,
) -> Any: ...  # THPModule_disable_dispatch_function
# Preferred linear-algebra backend getter/setter (values are _LinalgBackend).
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend) -> None: ...

# Linear-algebra backend choices (see the _get/_set_linalg_preferred_backend stubs).
class _LinalgBackend:
    Default: _LinalgBackend
    Cusolver: _LinalgBackend
    Magma: _LinalgBackend

# Return type of _select_batch_norm_backend; members are not declared here.
class BatchNormBackend(Enum): ...

# Preferred BLAS backend getter/setter (values are _BlasBackend).
def _get_blas_preferred_backend() -> torch._C._BlasBackend: ...
def _set_blas_preferred_backend(arg: torch._C._BlasBackend) -> None: ...

# BLAS backend choices (see the _get/_set_blas_preferred_backend stubs).
class _BlasBackend:
    Cublas: _BlasBackend
    Cublaslt: _BlasBackend

# Return type of _select_conv_backend; members are not declared here.
class ConvBackend(Enum): ...

1285
class Tag(Enum):
1286
    ${tag_attributes}
1287

# Defined in `valgrind.h` and `callgrind.h` respectively.
# Trailing comments name the macro each binding is built on.
def _valgrind_supported_platform() -> _bool: ...  # NVALGRIND
def _valgrind_toggle() -> None: ...  # CALLGRIND_TOGGLE_COLLECT
def _valgrind_toggle_and_dump_stats() -> None: ...  # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS

# Module-level capability flags (their names suggest build/runtime features --
# confirm against Module.cpp) and the default Generator.
has_openmp: _bool
has_mkl: _bool
_has_mps: _bool
has_lapack: _bool
_has_cuda: _bool
_has_magma: _bool
_has_xpu: _bool
_has_mkldnn: _bool
_has_cudnn: _bool
_has_cusparselt: _bool
has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
default_generator: Generator

# Defined in torch/csrc/autograd/init.cpp
# Grad-mode, forward-grad and inference-mode state accessors.
def _set_grad_enabled(enabled: _bool) -> None: ...
def is_grad_enabled() -> _bool: ...
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
def _is_fwd_grad_enabled() -> _bool: ...
def is_inference_mode_enabled() -> _bool: ...
# Autocast state. set_autocast_enabled / is_autocast_enabled accept an optional
# leading `device_type` (see the overloads).
@overload
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...
@overload
def set_autocast_enabled(enabled: _bool) -> None: ...
@overload
def is_autocast_enabled(device_type: str) -> _bool: ...
@overload
def is_autocast_enabled() -> _bool: ...
def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ...
def get_autocast_dtype(device_type: str) -> _dtype: ...
def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def _is_any_autocast_enabled() -> _bool: ...
def _is_autocast_available(device_type: str) -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...
def get_autocast_gpu_dtype() -> _dtype: ...
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def is_autocast_cache_enabled() -> _bool: ...
def set_autocast_cache_enabled(enabled: _bool) -> None: ...
# Version bumping, anomaly detection, multithreading / view-replay toggles,
# forward-AD dual levels and default saved-tensor hooks.
def _increment_version(tensors: Iterable[Tensor]) -> None: ...
def set_anomaly_enabled(enabled: _bool, check_nan: _bool = True) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def is_anomaly_check_nan_enabled() -> _bool: ...
def _is_multithreading_enabled() -> _bool: ...
def _set_multithreading_enabled(enabled: _bool) -> None: ...
def _set_view_replay_enabled(enabled: _bool) -> None: ...
def _is_view_replay_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
def _exit_dual_level(level: _int) -> None: ...
def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...
# Torch-function mode stack and torch-dispatch mode stack accessors.
def _is_torch_function_mode_enabled() -> _bool: ...
def _set_torch_function_mode(cls: Any) -> None: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: TorchDispatchMode) -> None: ...
def _pop_torch_dispatch_stack(mode_key: Optional[torch._C._TorchDispatchModeKey] = None) -> Any: ...
def _get_dispatch_mode(mode_key: Optional[torch._C._TorchDispatchModeKey]) -> Any: ...
def _unset_dispatch_mode(mode: torch._C._TorchDispatchModeKey) -> Optional[TorchDispatchMode]: ...
def _set_dispatch_mode(mode: TorchDispatchMode) -> None: ...
def _get_dispatch_stack_at(idx: _int) -> Any: ...
def _len_torch_dispatch_stack() -> _int: ...
def _activate_gpu_trace() -> None: ...

# Context-manager guard stubs (each declares __enter__/__exit__).
class _DisableTorchDispatch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _EnableTorchFunction:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _EnablePythonDispatcher:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _DisablePythonDispatcher:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _EnablePreDispatch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _DisableFuncTorch:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _DisableAutocast:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _InferenceMode:
    def __init__(self, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

# Autograd fallback mode, passed and returned as a string.
def _set_autograd_fallback_mode(mode: str) -> None: ...
def _get_autograd_fallback_mode() -> str: ...

# Defined in torch/csrc/jit/python/script_init.cpp
# Logger hierarchy: two concrete loggers derive from LoggerBase.
class LoggerBase: ...
class NoopLogger(LoggerBase): ...
class LockingLogger(LoggerBase): ...

# Aggregation kinds with fixed integer values.
class AggregationType(Enum):
    SUM = 0
    AVG = 1

# Builder-style text checker: each check_* returns the FileCheck so checks can
# be chained, and `run` applies them to a test string.
class FileCheck:
    def run(self, test_string: str) -> None: ...
    def check(self, test_string: str) -> FileCheck: ...
    def check_not(self, test_string: str) -> FileCheck: ...
    def check_same(self, test_string: str) -> FileCheck: ...
    def check_next(self, test_string: str) -> FileCheck: ...
    def check_count(
        self,
        test_string: str,
        count: _int,
        exactly: _bool = False,
    ) -> FileCheck: ...
    def check_dag(self, test_string: str) -> FileCheck: ...
    def check_source_highlighted(self, test_string: str) -> FileCheck: ...
    def check_regex(self, test_string: str) -> FileCheck: ...

# Defined in torch/csrc/jit/python/init.cpp
# Archive reader constructed from a file name or a binary buffer.
class PyTorchFileReader:
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    def get_record(self, name: str) -> bytes: ...
    def serialization_id(self) -> str: ...

# Archive writer constructed from a file name or a binary buffer.
class PyTorchFileWriter:
    @overload
    def __init__(self, name: str) -> None: ...
    @overload
    def __init__(self, buffer: BinaryIO) -> None: ...
    def write_record(
        self,
        name: str,
        data: Union[Storage, bytes, _int],
        size: _int,
    ) -> None: ...
    def write_end_of_file(self) -> None: ...
    def set_min_version(self, version: _int) -> None: ...
    def get_all_written_records(self) -> List[str]: ...
    def archive_name(self) -> str: ...
    def serialization_id(self) -> str: ...

# JIT inlining / logging configuration.
def _jit_get_inline_everything_mode() -> _bool: ...
def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
def _jit_get_logging_option() -> str: ...
def _jit_set_logging_option(option: str) -> None: ...
def _jit_set_logging_stream(stream_name: str) -> None: ...
# JIT graph passes; the graph argument is positional-only.
def _jit_pass_cse(g: Graph, /) -> _bool: ...
def _jit_pass_dce(g: Graph, /) -> None: ...
def _jit_pass_lint(g: Graph, /) -> None: ...

# Defined in torch/csrc/jit/python/python_custom_class.cpp
# Looks up attribute `attr` of the custom class registered as `name`.
def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...

# Defined in torch/csrc/Module.cpp
# Rename / query the name of the PrivateUse1 backend.
def _rename_privateuse1_backend(backend: str) -> None: ...
def _get_privateuse1_backend_name() -> str: ...

# Defined in torch/csrc/Generator.cpp
# Generator bound to a device. Its state can be read, replaced or cloned, and
# it exposes seeding and offset control. Pickling goes through __reduce__ /
# __setstate__.
class Generator:
    device: _device
    def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
    def __reduce__(
        self,
    ) -> Tuple[_Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
    def __setstate__(self, state: Tuple[_int, Optional[_int], Tensor]) -> None: ...
    def get_state(self) -> Tensor: ...
    def set_state(self, _new_state: Tensor) -> Generator: ...
    def clone_state(self) -> Generator: ...
    def graphsafe_get_state(self) -> Generator: ...
    def graphsafe_set_state(self, _new_state: Generator) -> Generator: ...
    def set_offset(self, offset: _int) -> Generator: ...
    def get_offset(self) -> _int: ...
    def manual_seed(self, seed: _int) -> Generator: ...
    def seed(self) -> _int: ...
    def initial_seed(self) -> _int: ...

# Defined in torch/csrc/utils/python_dispatch.cpp

# Handle to a registered operator: exposes its schema and a debug string.
class _DispatchOperatorHandle:
    def schema(self) -> FunctionSchema: ...
    def debug(self) -> str: ...

# Library handle returned by _dispatch_library: defines schemas and registers
# implementations / fallbacks. Most definers return the module for chaining.
class _DispatchModule:
    def reset(self) -> None: ...
    def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...
    def def_legacy(self, schema: str) -> _DispatchModule: ...
    def def_name_t_t(
        self,
        name: str,
        dispatch: str,
        debug: str = "default_def_name_t_t",
    ) -> _DispatchModule: ...
    def def_schema_t_t(
        self,
        schema: str,
        dispatch: str,
        alias: str,
        debug: str = "default_def_schema_t_t",
    ) -> _DispatchModule: ...
    def impl_t_t(
        self,
        name: str,
        dispatch: str,
        debug: str = "impl_t_t",
    ) -> _DispatchModule: ...
    def impl_with_aoti_compile(
        self,
        ns: str,
        op_name_with_overload: str,
        dispatch: _dispatchkey,
    ) -> None: ...
    def impl(self, name: str, dispatch: _dispatchkey, func: Callable) -> None: ...
    def define(self, schema: str, alias: str = "") -> str: ...
    def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...
    def fallback(self, dispatch: _dispatchkey, func: Callable, with_keyset: _bool = False) -> None: ...

# Precomputed dispatch key sets exposed by the extension.
_after_ADInplaceOrView_keyset: DispatchKeySet
_after_autograd_keyset: DispatchKeySet

# Dispatcher library creation, debugging dumps, invariant checks and boxed calls.
def _dispatch_library(
    kind: str,
    name: str,
    dispatch: str,
    file: str = "",
    linenum: Any = 0,
) -> _DispatchModule: ...
def _dispatch_dump(name: str) -> str: ...
def _dispatch_dump_table(name: str) -> str: ...
def _dispatch_check_invariants(name: str) -> None: ...
def _dispatch_check_all_invariants() -> None: ...
def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ...
def _dispatch_find_schema_or_throw(name: str, overload_name: str) -> _DispatchOperatorHandle: ...
def _dispatch_set_report_error_callback(handle: _DispatchOperatorHandle, callback: Callable) -> None: ...
# Kernel-registration queries.
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_has_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_has_kernel_for_any_dispatch_key(
    name: str,
    dispatch_key_set: DispatchKeySet,
) -> _bool: ...
def _dispatch_kernel_for_dispatch_key_is_fallthrough(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(
    name: str,
    dispatch: _dispatchkey,
) -> _bool: ...
def _dispatch_find_dangling_impls() -> List[str]: ...
def _dispatch_get_all_op_names() -> List[str]: ...
# Thread-local (TLS) include/exclude dispatch-key state, dispatch-key parsing
# helpers and functionalization / lifting / mutable-data-ptr toggles.
def _dispatch_tls_set_dispatch_key_excluded(
    dispatch: _dispatchkey,
    val: _bool,
) -> None: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_tls_set_dispatch_key_included(
    dispatch: _dispatchkey,
    val: _bool,
) -> None: ...
def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_key_name(dispatch: _dispatchkey) -> str: ...
def _dispatch_key_for_device(device_type: str) -> str: ...
def _parse_dispatch_key(key: str) -> Optional[DispatchKey]: ...
def _dispatch_key_parse(dispatch: _dispatchkey) -> DispatchKey: ...
def _dispatch_num_backends() -> _int: ...
def _dispatch_pystub(name: str, overload: str) -> Optional[Tuple[str, str]]: ...
def _dispatch_is_alias_key(dispatch: _dispatchkey) -> _bool: ...
def _functionality_to_backend_keys(dispatch: _dispatchkey) -> List[DispatchKey]: ...
def _functionalization_reapply_views_tls() -> _bool: ...
def _only_lift_cpu_tensors() -> _bool: ...
def _set_only_lift_cpu_tensors(value: _bool) -> None: ...
def _set_throw_on_mutable_data_ptr(tensor: Tensor) -> None: ...
def _set_warn_deprecated_on_mutable_data_ptr(tensor: Tensor) -> None: ...

1591
class DispatchKey(Enum):
1592
    ${dispatch_key_hints}
1593

1594
class DispatchKeySet:
    # Set operators each return a new DispatchKeySet.
    def __init__(self, key: DispatchKey) -> None: ...
    def __or__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __sub__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def __and__(self, other: DispatchKeySet) -> DispatchKeySet: ...
    def raw_repr(self) -> _int: ...
    def highestPriorityTypeId(self) -> DispatchKey: ...
    def has(self, k: _dispatchkey) -> _bool: ...
    def add(self, k: _dispatchkey) -> DispatchKeySet: ...
    def remove(self, k: _dispatchkey) -> DispatchKeySet: ...
    def __repr__(self) -> str: ...

_dispatch_autogradother_backends: DispatchKeySet
_additional_keys_to_prop_for_wrapper_tensors: DispatchKeySet

def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keyset_full() -> DispatchKeySet: ...
def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
def _dispatch_get_backend_keyset_from_autograd(
    dispatch: _dispatchkey,
) -> DispatchKeySet: ...
def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ...
def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ...
def _dispatch_tls_local_include_set() -> DispatchKeySet: ...
def _dispatch_is_included_in_alias(
    dispatch_a: _dispatchkey,
    dispatch_b: _dispatchkey,
) -> _bool: ...
def _propagate_xla_data(a: Tensor, b: Tensor) -> None: ...
def _replace_(a: Tensor, b: Tensor) -> None: ...
def _commit_update(a: Tensor) -> None: ...

# Context-manager guard classes (signatures only).
class _ExcludeDispatchKeyGuard:
    def __init__(self, keyset: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _IncludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _ForceDispatchKeyGuard:
    def __init__(self, include: DispatchKeySet, exclude: DispatchKeySet) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _PreserveDispatchKeyGuard:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowAutograd:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

class _AutoDispatchBelowADInplaceOrView:
    def __init__(self) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
def _dispatch_get_registrations_for_dispatch_key(
    dispatch_key: str = "",
) -> List[str]: ...
def _are_functorch_transforms_active() -> _bool: ...

# Defined in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...

def _get_nested_int(id: _int, coeff: _int) -> SymInt: ...

def _get_constant_bool_symnode(val: _bool) -> Any: ...

class _TorchDispatchModeKey(Enum):
1671
    ${torch_dispatch_mode_key_hints}
1672

1673
class _SetExcludeDispatchKeyGuard:
    def __init__(self, k: DispatchKey, enabled: _bool) -> None: ...
    def __enter__(self): ...
    def __exit__(self, exc_type, exc_value, traceback): ...

# Defined in torch/csrc/utils/schema_info.h

class _SchemaInfo:
    def __init__(self, schema: _int) -> None: ...

    @overload
    def is_mutable(self) -> _bool: ...
    @overload
    def is_mutable(self, name: str) -> _bool: ...

    def has_argument(self, name: str) -> _bool: ...

# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig:
    num_calling_threads: _int
    num_worker_threads: _int
    num_warmup_iters: _int
    num_iters: _int
    profiler_output_path: str

class BenchmarkExecutionStats:
    latency_avg_ms: _float
    num_iters: _int

class ThroughputBenchmark:
    def __init__(self, module: Any) -> None: ...
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...

# Defined in torch/csrc/Storage.cpp
1709
${legacy_storage_base_hints}
1710

1711
# TODO: where
1712
${legacy_class_hints}
1713

1714
# Defined in torch/csrc/autograd/python_engine.cpp
class _ImperativeEngine:
    def queue_callback(self, callback: Callable[[], None]) -> None: ...
    def run_backward(self, *args: Any, **kwargs: Any) -> Tuple[Tensor, ...]: ...
    def is_checkpoint_valid(self) -> _bool: ...

# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorMeta(type): ...

# Defined in torch/csrc/autograd/python_variable.cpp
1724
class TensorBase(metaclass=_TensorMeta):
1725
    requires_grad: _bool
1726
    retains_grad: _bool
1727
    shape: Size
1728
    data: Tensor
1729
    names: List[str]
1730
    device: _device
1731
    dtype: _dtype
1732
    layout: _layout
1733
    real: Tensor
1734
    imag: Tensor
1735
    T: Tensor
1736
    H: Tensor
1737
    mT: Tensor
1738
    mH: Tensor
1739
    ndim: _int
1740
    output_nr: _int
1741
    _version: _int
1742
    _base: Optional[Tensor]
1743
    _cdata: _int
1744
    grad_fn: Optional[_Node]
1745
    _grad_fn: Any
1746
    _grad: Optional[Tensor]
1747
    grad: Optional[Tensor]
1748
    _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]]
1749
    nbytes: _int
1750
    itemsize: _int
1751
    _has_symbolic_sizes_strides: _bool
1752

1753
    def _view_func_unsafe(
1754
        self,
1755
        new_base: Tensor,
1756
        symint_visitor_fn: Optional[Callable[[_int], _int]] = None,
1757
        tensor_visitor_fn: Optional[Callable[[Tensor], Tensor]] = None
1758
    ):
1759
        ...
1760

1761
    ${tensor_method_hints}
1762

1763
_TensorBase = TensorBase
1764

1765
# Defined in torch/csrc/multiprocessing/init.cpp
def _multiprocessing_init() -> None: ...
def _set_thread_name(name: str) -> None: ...
def _get_thread_name() -> str: ...

# Defined in torch/csrc/Module.cpp
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int: ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
def _get_accelerator(check: _bool = False) -> _device: ...

# Defined in torch/csrc/mtia/Module.cpp
def _mtia_init() -> None: ...
def _mtia_isBuilt() -> _bool: ...
def _mtia_isInBadFork() -> _bool: ...
def _mtia_deviceSynchronize() -> None: ...
def _mtia_getCurrentStream(device: _int) -> Stream: ...
def _mtia_setCurrentStream(stream: Stream) -> None: ...
def _mtia_getDefaultStream(device: _int) -> Stream: ...
def _mtia_memoryStats(device: _int) -> Dict[str, Any]: ...

# Defined in torch/csrc/mps/Module.cpp
def _mps_deviceSynchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...
def _mps_emptyCache() -> None: ...
def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_recommendedMaxMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_or_newer(major: _int, minor: _int) -> _bool: ...
def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
def _mps_profilerStopTrace() -> None: ...
def _mps_acquireEvent(enable_timing: _bool) -> _int: ...
def _mps_releaseEvent(event_id: _int) -> None: ...
def _mps_recordEvent(event_id: _int) -> None: ...
def _mps_waitForEvent(event_id: _int) -> None: ...
def _mps_synchronizeEvent(event_id: _int) -> None: ...
def _mps_queryEvent(event_id: _int) -> _bool: ...
def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...

# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
def _cuda_getCurrentRawStream(device: _int) -> _int: ...
def _cuda_getDefaultStream(device: _int) -> Tuple: ...
def _cuda_getCurrentBlasHandle() -> _int: ...
def _cuda_clearCublasWorkspaces() -> None: ...
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_exchangeDevice(device: _int) -> _int: ...
def _cuda_maybeExchangeDevice(device: _int) -> _int: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...
def _cuda_get_sync_debug_mode() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> Optional[str]: ...
def _cuda_init() -> None: ...
def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...
def _cuda_cudaHostAllocator() -> _int: ...
def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
def _cuda_beginAllocateToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_beginAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_endAllocateCurrentStreamToPool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_releasePool(device: _int, mempool_id: Tuple[_int, _int]) -> None: ...
def _cuda_checkPoolLiveAllocations(
    device: _int,
    mempool_id: Tuple[_int, _int],
    expected_live_allocations: Set,
) -> _bool: ...
def _cuda_setCheckpointPoolState(
    device: _int,
    state: _cuda_CUDAAllocator_AllocatorState,
    stale_storages: List[_int],
    storages_to_add_deleters_to: List[_int],
) -> None: ...
def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ...
def _cuda_emptyCache() -> None: ...
def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
def _cuda_record_memory_history_legacy(
    enabled: _bool,
    record_context: _bool,
    record_context_cpp: _bool,
    alloc_trace_max_entries: _int,
    alloc_trace_record_context: _bool,
) -> None: ...
def _cuda_record_memory_history(
    enabled: Optional[str],
    context: Optional[str],
    stacks: str,
    max_entries: _int,
) -> None: ...
def _cuda_isHistoryEnabled() -> _bool: ...

def _cuda_getAllocatorBackend() -> str: ...
class _cuda_CUDAAllocator_AllocatorState:
    pass
def _cuda_getCheckpointState(device: _int, mempool: Tuple[_int, _int]) -> _cuda_CUDAAllocator_AllocatorState: ...
def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
def _add_cached_tensor(t: Tensor) -> None: ...
def _remove_cached_tensor(t: Tensor) -> None: ...
def _tensors_data_ptrs_at_indices_equal(
    tensors: List[Union[Tensor, _int]],
    ptrs: List[Optional[_int]],
    indices: List[_int],
) -> _bool: ...
def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
def _storage_Use_Count(storage_ptr: _int) -> _int: ...
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ...
def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ...

class _cuda_CUDAAllocator: ...

def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ...
def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ...
def _cuda_getAllocator() -> _cuda_CUDAAllocator: ...
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
def _cuda_jiterator_compile_and_launch_kernel(
    code_string: str,
    kernel_name: str,
    return_by_ref: _bool,
    num_outputs: _int,
    tensors: Tuple,
    kwargs: Dict[str, Union[_int, _float, _bool]],
) -> Tensor: ...
def _cuda_get_cudnn_benchmark_limit() -> _int: ...
def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
def _cuda_get_conv_benchmark_empty_cache() -> _bool: ...
def _cudnn_set_conv_benchmark_empty_cache(enable: _bool) -> None: ...
def _nccl_version() -> _int: ...
def _nccl_version_suffix() -> bytes: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
def _nccl_reduce(
    input: Sequence[Tensor],
    output: Tensor,
    root: _int,
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_all_reduce(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_broadcast(
    input: Sequence[Tensor],
    root: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_all_gather(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _nccl_reduce_scatter(
    input: Sequence[Tensor],
    output: Sequence[Tensor],
    op: _int,
    streams: Optional[Sequence[_CudaStreamBase]],
    comms: Optional[Sequence[object]],
) -> None: ...
def _rocm_is_backward_pass() -> _bool: ...

# TunableOp controls (signatures only).
def _cuda_tunableop_enable(val: _bool) -> None: ...
def _cuda_tunableop_is_enabled() -> _bool: ...
def _cuda_tunableop_tuning_enable(val: _bool) -> None: ...
def _cuda_tunableop_tuning_is_enabled() -> _bool: ...
def _cuda_tunableop_set_max_tuning_duration(duration: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_duration() -> _int: ...
def _cuda_tunableop_set_max_tuning_iterations(iterations: _int) -> None: ...
def _cuda_tunableop_get_max_tuning_iterations() -> _int: ...
def _cuda_tunableop_set_filename(filename: str, insert_device_ordinal: Optional[_bool]) -> None: ...
def _cuda_tunableop_get_filename() -> str: ...
def _cuda_tunableop_write_file(filename: Optional[str]) -> _bool: ...
def _cuda_tunableop_read_file(filename: Optional[str]) -> _bool: ...
def _cuda_tunableop_write_file_on_exit(val: _bool) -> None: ...
def _cuda_tunableop_get_results() -> Tuple[str, str, str, _float]: ...
def _cuda_tunableop_get_validators() -> Tuple[str, str]: ...

class _CudaDeviceProperties:
    name: str
    major: _int
    minor: _int
    multi_processor_count: _int
    total_memory: _int
    is_integrated: _int
    is_multi_gpu_board: _int
    max_threads_per_multi_processor: _int
    gcnArchName: str
    warp_size: _int
    uuid: str
    L2_cache_size: _int

# Functions related to scaled dot-product attention (SDPA)
class _SDPAParams:
    query: Tensor
    key: Tensor
    value: Tensor
    attn_mask: Optional[Tensor]
    dropout: _float
    is_causal: _bool
    enable_gqa: _bool
    def __init__(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Optional[Tensor],
        dropout: _float,
        is_causal: _bool,
        enable_gqa: _bool,
    ) -> None: ...

class _SDPBackend(Enum):
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    CUDNN_ATTENTION = 3

def _is_flash_attention_available() -> _bool: ...
def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ...
def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ...

# Defined in torch/csrc/cuda/GdsFile.cpp
def _gds_register_buffer(t: Storage) -> None: ...
def _gds_deregister_buffer(t: Storage) -> None: ...
def _gds_register_handle(fd: _int) -> _int: ...
def _gds_deregister_handle(handle: _int) -> None: ...
def _gds_load_storage(handle: _int, s: Storage, offset: _int) -> None: ...
def _gds_save_storage(handle: _int, s: Storage, offset: _int) -> None: ...

# Defined in torch/csrc/cuda/python_comm.cpp
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
def _broadcast_coalesced(
    tensors: List[Tensor],
    devices: List[_int],
    buffer_size: _int,
) -> List[List[Tensor]]: ...
def _scatter(
    tensor: Tensor,
    devices: List[_int],
    chunk_sizes: Optional[List[_int]],
    dim: _int,
    streams: Optional[List[Stream]],
) -> List[Tensor]: ...
def _scatter_out(
    tensor: Tensor,
    out_tensors: List[Tensor],
    dim: _int,
    streams: Optional[List[Stream]],
) -> List[Tensor]: ...
def _gather(
    tensors: List[Tensor],
    dim: _int,
    destination_index: Optional[_int],
) -> Tensor: ...
def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ...

# Defined in torch/csrc/cuda/Stream.cpp
class _CudaStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    cuda_stream: _int
    priority: _int

    # `__new__` receives the class, not an instance.
    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        stream_ptr: _int = 0,
    ) -> _CudaStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    def priority_range(self) -> Tuple[_int, _int]: ...

# Defined in torch/csrc/cuda/Event.cpp
class _CudaEventBase:
    device: _device
    cuda_event: _int

    def __new__(
        cls,
        enable_timing: _bool = False,
        blocking: _bool = False,
        interprocess: _bool = False,
    ) -> _CudaEventBase: ...
    @classmethod
    def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ...
    def record(self, stream: _CudaStreamBase) -> None: ...
    def wait(self, stream: _CudaStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _CudaEventBase) -> _float: ...
    def synchronize(self) -> None: ...
    def ipc_handle(self) -> bytes: ...

# Defined in torch/csrc/cuda/Graph.cpp
class _CUDAGraph:
    def capture_begin(self, pool: Optional[Tuple[_int, _int]] = ..., capture_error_mode: str = "global") -> None: ...
    def capture_end(self) -> None: ...
    # NOTE(review): this parameter is literally named `Generator` and is unannotated;
    # presumably it takes a torch Generator. Confirm against Graph.cpp before renaming it.
    def register_generator_state(self, Generator) -> None: ...
    def replay(self) -> None: ...
    def reset(self) -> None: ...
    def pool(self) -> Tuple[_int, _int]: ...
    def enable_debug_mode(self) -> None: ...
    def debug_dump(self, debug_path: str) -> None: ...

# Defined in torch/csrc/cuda/MemPool.cpp
class _MemPool:
    def __init__(self, allocator: Optional[_cuda_CUDAAllocator] = None, is_user_created: _bool = True) -> None: ...
    @property
    def id(self) -> Tuple[_int, _int]: ...
    @property
    def allocator(self) -> Optional[_cuda_CUDAAllocator]: ...

class _MemPoolContext:
    def __init__(self, pool: _MemPool) -> None: ...
    @staticmethod
    def active_pool() -> Optional[_MemPool]: ...

def _cuda_isCurrentStreamCapturing() -> _bool: ...
def _graph_pool_handle() -> Tuple[_int, _int]: ...

# Defined in torch/csrc/xpu/Module.cpp
def _xpu_setDevice(device: _int) -> None: ...
def _xpu_exchangeDevice(device: _int) -> _int: ...
def _xpu_maybeExchangeDevice(device: _int) -> _int: ...
def _xpu_getDevice() -> _int: ...
def _xpu_getDeviceCount() -> _int: ...
def _xpu_init() -> None: ...
def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
def _xpu_getCurrentStream(device: _int) -> Tuple: ...
def _xpu_getCurrentRawStream(device: _int) -> _int: ...
def _xpu_synchronize(device: _int) -> None: ...
def _xpu_emptyCache() -> None: ...

class _XpuDeviceProperties:
    name: str
    platform_name: str
    vendor: str
    driver_version: str
    version: str
    total_memory: _int
    max_compute_units: _int
    gpu_eu_count: _int
    gpu_subslice_count: _int
    max_work_group_size: _int
    max_num_sub_groups: _int
    sub_group_sizes: List[_int]
    has_fp16: _bool
    has_fp64: _bool
    has_atomic64: _bool
    type: str

# Defined in torch/csrc/xpu/Stream.cpp
class _XpuStreamBase(Stream):
    stream_id: _int
    device_index: _int
    device_type: _int

    device: _device
    sycl_queue: _int
    priority: _int

    def __new__(
        cls,
        priority: _int = 0,
        stream_id: _int = 0,
        device_index: _int = 0,
        device_type: _int = 0,
    ) -> _XpuStreamBase: ...
    def query(self) -> _bool: ...
    def synchronize(self) -> None: ...
    @staticmethod
    def priority_range() -> Tuple: ...

# Defined in torch/csrc/xpu/Event.cpp
class _XpuEventBase:
    device: _device
    sycl_event: _int

    def __new__(cls, enable_timing: _bool = False) -> _XpuEventBase: ...
    # Both record and wait take a stream (previously `record` was mistyped as taking an event).
    def record(self, stream: _XpuStreamBase) -> None: ...
    def wait(self, stream: _XpuStreamBase) -> None: ...
    def query(self) -> _bool: ...
    def elapsed_time(self, other: _XpuEventBase) -> _float: ...
    def synchronize(self) -> None: ...

# Defined in torch/csrc/DataLoader.cpp
def _set_worker_signal_handlers(
    *arg: Any,
) -> None: ...  # THPModule_setWorkerSignalHandlers
def _set_worker_pids(
    key: _int,
    child_pids: Tuple[_int, ...],
) -> None: ...  # THPModule_setWorkerPIDs
def _remove_worker_pids(loader_id: _int) -> None: ...  # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ...  # THPModule_errorIfAnyWorkerFails

# Defined in torch/csrc/jit/python/python_tracer.cpp
class TracingState:
    def push_scope(self, scope_name: str) -> None: ...
    def pop_scope(self) -> None: ...
    def current_scope(self) -> str: ...
    def set_graph(self, graph: Graph) -> None: ...
    def graph(self) -> Graph: ...

def _create_graph_by_tracing(
    func: Callable[..., Any],
    inputs: Any,
    var_name_lookup_fn: Callable[[Tensor], str],
    strict: Any,
    force_outplace: Any,
    self: Any = None,
    argument_names: List[str] = [],
) -> Tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ...

# Listed under torch/csrc/jit/python/python_ir.cpp, but IValue and Stack are not
# actually defined there; their C++ origin is unconfirmed.
class IValue: ...

Stack = List[IValue]

class JitType:
    annotation_str: str
    def isSubtypeOf(self, other: JitType) -> _bool: ...
    def with_dtype(self, dtype: _dtype) -> JitType: ...
    def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ...
    def kind(self) -> str: ...
    def scalarType(self) -> Optional[str]: ...
    def getElementType(self) -> JitType: ...
    def dtype(self) -> Optional[_dtype]: ...

class InferredType:
    def __init__(self, arg: Union[JitType, str]) -> None: ...
    def type(self) -> JitType: ...
    def success(self) -> _bool: ...
    def reason(self) -> str: ...

# Type variable bounded by JitType (used to parameterize OptionalType).
R = TypeVar("R", bound=JitType)

class Type(JitType):
    # The `str` method shadows the builtin inside this class body, so `_str` is used
    # for string annotations below.
    def str(self) -> _str: ...
    def containedTypes(self) -> List[JitType]: ...
    def dim(self) -> Optional[_int]: ...
    def undefined(self) -> Optional[_bool]: ...
    def sizes(self) -> Optional[List[_int]]: ...
    def symbol_sizes(self) -> Optional[List[_int]]: ...
    def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
    def strides(self) -> Optional[List[_int]]: ...
    def contiguous(self) -> Self: ...
    def device(self) -> Optional[_device]: ...
    def __eq__(self, other: object) -> _bool: ...
    # Custom __eq__ without a hash: instances are declared unhashable.
    __hash__ = None  # type: ignore[assignment]
    def is_interface_type(self) -> _bool: ...
    def requires_grad(self) -> _bool: ...
    @property
    def annotation_string(self) -> _str: ...

# Singleton JIT types: each exposes a static `get()` returning its own type.
class AnyType(JitType):
    @staticmethod
    def get() -> AnyType: ...

class NoneType(JitType):
    @staticmethod
    def get() -> NoneType: ...

class BoolType(JitType):
    @staticmethod
    def get() -> BoolType: ...

class FloatType(JitType):
    @staticmethod
    def get() -> FloatType: ...

class ComplexType(JitType):
    @staticmethod
    def get() -> ComplexType: ...

class IntType(JitType):
    @staticmethod
    def get() -> IntType: ...

class SymIntType(JitType):
    @staticmethod
    def get() -> SymIntType: ...

class SymBoolType(JitType):
    @staticmethod
    def get() -> SymBoolType: ...

class NumberType(JitType):
    @staticmethod
    def get() -> NumberType: ...

class StringType(JitType):
    @staticmethod
    def get() -> StringType: ...

class DeviceObjType(JitType):
    @staticmethod
    def get() -> DeviceObjType: ...

class _GeneratorType(JitType):
    @staticmethod
    def get() -> _GeneratorType: ...

class StreamObjType(JitType):
    @staticmethod
    def get() -> StreamObjType: ...

class ListType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    @staticmethod
    def ofInts() -> ListType: ...
    @staticmethod
    def ofTensors() -> ListType: ...
    @staticmethod
    def ofFloats() -> ListType: ...
    @staticmethod
    def ofComplexDoubles() -> ListType: ...
    @staticmethod
    def ofBools() -> ListType: ...
    @staticmethod
    def ofStrings() -> ListType: ...

class DictType(JitType):
    def __init__(self, key: JitType, value: JitType) -> None: ...
    def getKeyType(self) -> JitType: ...
    def getValueType(self) -> JitType: ...

class TupleType(JitType):
    def __init__(self, a: List[Optional[JitType]]) -> None: ...
    def elements(self) -> List[JitType]: ...

class UnionType(JitType):
    def __init__(self, a: List[JitType]) -> None: ...

class ClassType(JitType):
    def __init__(self, qualified_name: str) -> None: ...

class InterfaceType(JitType):
    def __init__(self, qualified_name: str) -> None: ...
    def getMethod(self, name: str) -> Optional[FunctionSchema]: ...
    def getMethodNames(self) -> List[str]: ...

class OptionalType(JitType, Generic[R]):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...
    @staticmethod
    def ofTensor() -> OptionalType: ...

class FutureType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class AwaitType(JitType):
    def __init__(self, a: JitType) -> None: ...
    def getElementType(self) -> JitType: ...

class RRefType(JitType):
    def __init__(self, a: JitType) -> None: ...

class EnumType(JitType):
    def __init__(
        self,
        qualified_name: str,
        value_type: JitType,
        enum_names_values: List[Any],
    ) -> None: ...

class TensorType(JitType):
    @classmethod
    def get(cls) -> TensorType: ...
    @classmethod
    def getInferred(cls) -> TensorType: ...
    def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
    def sizes(self) -> Optional[List[_int]]: ...
    def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
    def strides(self) -> Optional[List[_int]]: ...
    def device(self) -> Optional[_device]: ...
    def dim(self) -> _int: ...
    def dtype(self) -> Optional[_dtype]: ...
    @staticmethod
    def create_from_tensor(t: Tensor) -> TensorType: ...

# Defined in torch/csrc/jit/python/python_tree_views.cpp
class SourceRange: ...
class TreeView: ...

class Ident(TreeView):
    @property
    def name(self) -> str: ...

class ClassDef(TreeView): ...

class Def(TreeView):
    def name(self) -> Ident: ...

class Decl(TreeView): ...

# Initialization entry points for the distributed bindings. Each returns a
# _bool (presumably a success flag — confirm in the corresponding init.cpp).

# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...

# Defined in torch/csrc/distributed/autograd/init.cpp
def _dist_autograd_init() -> _bool: ...

# Defined in torch/csrc/distributed/c10d/init.cpp
def _c10d_init() -> _bool: ...

# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
# By its name, registers the Python class `cls` for the device string
# `device`; returns None.
def _register_py_class_for_device(device: str, cls: Any) -> None: ...

# Defined in torch/csrc/Module.cpp
# Autograd-engine introspection helpers (by their names: the id of the
# current graph task, the autograd node currently executing, and whether the
# engine will execute `node`).
def _current_graph_task_id() -> _int: ...
def _current_autograd_node() -> _Node: ...
def _will_engine_execute_node(node: _Node) -> _bool: ...
# Returns the dispatch key set of `tensor` rendered as a str; the parameter
# is left unannotated in this stub.
def _dispatch_key_set(tensor) -> str: ...

# Defined in torch/csrc/Exceptions.cpp
# In this stub every exception below derives directly from RuntimeError.
class OutOfMemoryError(RuntimeError): ...
class _DistError(RuntimeError): ...
class _DistBackendError(RuntimeError): ...
class _DistStoreError(RuntimeError): ...
class _DistNetworkError(RuntimeError): ...

# Defined in torch/csrc/profiler/init.cpp
# Opaque handle; no members are typed in this stub.
class CapturedTraceback: ...

# Captures a traceback; by their names, the flags choose whether Python,
# TorchScript and C++ frames are included.
def gather_traceback(python: _bool, script: _bool, cpp: _bool) -> CapturedTraceback: ...
# Converts captured tracebacks into a list of dicts (presumably one per input
# traceback; the dict schema is not typed here).
def symbolize_tracebacks(tracebacks: List[CapturedTraceback]) -> List[Dict[str, Any]]: ...

# Load/save helpers for mobile (LiteScriptModule) and JIT (ScriptModule)
# modules, from/to files or bytes. Loader return types are left untyped here.
def _load_mobile_module_from_file(filename: str): ...
def _load_mobile_module_from_bytes(bytes_: bytes): ...
def _load_jit_module_from_file(filename: str): ...
def _load_jit_module_from_bytes(bytes_: bytes): ...
def _save_mobile_module(m: LiteScriptModule, filename: str): ...
def _save_jit_module(m: ScriptModule, filename: str, extra_files: Dict[str, Any]): ...
def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
def _save_jit_module_to_bytes(m: ScriptModule, extra_files: Dict[str, Any]) -> bytes: ...
def _get_module_info_from_flatbuffer(data: bytes): ...
def _jit_resolve_packet(op_name: str, *args, **kwargs) -> str: ...
def _swap_tensor_impl(t1: Tensor, t2: Tensor): ...
# Serialize an arbitrary object to bytes and back.
# NOTE(review): if _pickle_load_obj behaves like pickle.loads, it must not be
# given untrusted data — confirm in the C++ binding.
def _pickle_save(obj: Any) -> bytes: ...
def _pickle_load_obj(bs: bytes) -> Any: ...

# Defined in torch/csrc/jit/runtime/static/init.cpp
# Both accept either a Graph or a ScriptModule; results are untyped (Any).
def _jit_to_static_module(graph_or_module: Union[Graph, ScriptModule]) -> Any: ...
def _fuse_to_static_module(graph_or_module: Union[Graph, ScriptModule], min_size: _int) -> Any: ...

# Defined in torch/csrc/fx/node.cpp
class _NodeBase:
    # Erasure flag plus previous/next FxNode links (presumably the owning
    # graph's node order — confirm in node.cpp).
    _erased: _bool
    _prev: FxNode
    _next: FxNode

class _NodeIter(Iterator[FxNode]):
    # Iterator yielding FxNode objects, starting from `root`; by its name,
    # `reversed` selects the traversal direction (confirm in node.cpp).
    def __init__(self, root: FxNode, reversed: _bool) -> None: ...
    def __iter__(self) -> Iterator[FxNode]: ...
    def __next__(self) -> FxNode: ...

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.