pytorch

Форк
0
/
overrides.py 
1973 строки · 101.2 Кб
1
"""
2
Python implementation of ``__torch_function__``
3

4
While most of the torch API and handling for ``__torch_function__`` happens
5
at the C++ level, some of the torch API is written in Python so we need
6
python-level handling for ``__torch_function__`` overrides as well. The main
7
developer-facing functionality in this file are handle_torch_function and
8
has_torch_function. See torch/functional.py and test/test_overrides.py
9
for usage examples.
10

11
Note
12
----
13
heavily inspired by NumPy's ``__array_function__`` (see:
14
https://github.com/pytorch/pytorch/issues/24015 and
15
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
16
)
17

18
If changing this file in a way that can affect ``__torch_function__`` overhead,
19
please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
20
instructions in the ``README.md`` in that directory.
21
"""
22

23
import __future__  # noqa: F404
24

25
import collections
26
import functools
27
import types
28
import warnings
29
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
30
from functools import wraps
31
import contextlib
32

33
import torch
34
from torch._C import (
35
    _has_torch_function, _has_torch_function_unary,
36
    _has_torch_function_variadic, _add_docstr,
37
    _push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
38
    _is_torch_function_mode_enabled)
39

40
# Names exported by ``from torch.overrides import *``; the order below is the
# order in which they are listed for star-imports.
__all__ = [
    "get_ignored_functions",
    "get_overridable_functions",
    "get_testing_overrides",
    "handle_torch_function",
    "has_torch_function",
    "resolve_name",
    "is_tensor_like",
    "is_tensor_method_or_property",
    "wrap_torch_function",
    "enable_reentrant_dispatch",
]
52

53

54
def _disable_user_warnings(
        func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
    """
    Wrap ``func`` so that matching ``UserWarning``s are ignored while it runs.

    During each call of the returned wrapper, any ``UserWarning`` whose message
    matches ``regex`` and which is attributed to a module matching ``module`` is
    ignored. Other warnings pass through unchanged.

    Arguments
    ---------
    func : function
        The callable whose warnings should be silenced.
    regex : str
        Regular expression handed to ``warnings.filterwarnings`` as ``message``;
        it is matched against the text of each ``UserWarning``.
    module : str
        Regular expression handed to ``warnings.filterwarnings`` as ``module``;
        it restricts the filter to warnings attributed to matching module names.

    Returns
    -------
    function
        A ``functools.wraps``-decorated wrapper around ``func`` with the same
        call signature and return value.
    """

    def _call_silenced(*args, **kwargs):
        # catch_warnings() saves the global filter list and restores it on exit,
        # so the "ignore" filter installed here only applies to this one call.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=regex,
                category=UserWarning,
                module=module,
            )
            outcome = func(*args, **kwargs)
        return outcome

    return wraps(func)(_call_silenced)
81

82

83
@functools.lru_cache(None)
# NOTE(review): presumably evaluating some entries below (deprecated attributes)
# emits "... is deprecated, please use ..." UserWarnings, which this decorator
# silences -- confirm which entries actually warn.
@_disable_user_warnings
def get_ignored_functions() -> Set[Callable]:
    """
    Return public functions that cannot be overridden by ``__torch_function__``.

    Returns
    -------
    Set[Callable]
        A set of functions that are publicly available in the torch API but cannot
        be overridden with ``__torch_function__``. Mostly this is because none of the
        arguments of these functions are tensors or tensor-likes. The set also
        contains a number of private helpers (e.g. ``Tensor._conj``,
        ``torch._make_dep_token``).

        The set is built once and cached by ``functools.lru_cache``. Every call
        returns the same object, so callers must not mutate it.

    Examples
    --------
    >>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
    True
    >>> torch.add in torch.overrides.get_ignored_functions()
    False
    """
    Tensor = torch.Tensor
    return {
        torch.typename,
        torch.is_tensor,
        torch.is_storage,
        torch.set_default_tensor_type,
        torch.set_default_device,
        torch.get_default_device,
        torch.set_rng_state,
        torch.get_rng_state,
        torch.manual_seed,
        torch.initial_seed,
        torch.seed,
        torch.save,
        torch.load,
        torch.set_printoptions,
        torch.fork,
        torch.get_default_dtype,
        torch.get_num_interop_threads,
        torch.get_num_threads,
        torch.init_num_threads,
        torch.import_ir_module,
        torch.import_ir_module_from_buffer,
        torch.is_anomaly_enabled,
        torch.is_anomaly_check_nan_enabled,
        torch.is_grad_enabled,
        torch.merge_type_from_type_comment,
        torch.parse_ir,
        torch.parse_schema,
        torch.parse_type_comment,
        torch.set_anomaly_enabled,
        torch.set_flush_denormal,
        torch.set_num_interop_threads,
        torch.set_num_threads,
        torch.wait,
        torch.as_tensor,
        torch.from_numpy,
        torch.get_device,
        torch.tensor,
        torch.default_generator,
        torch.has_cuda,
        torch.has_cudnn,
        torch.has_lapack,
        torch.device,
        torch.dtype,
        torch.finfo,
        torch.has_mkl,
        torch.has_mps,
        torch.has_mkldnn,
        torch.has_openmp,
        torch.iinfo,
        torch.memory_format,
        torch.qscheme,
        torch.set_grad_enabled,
        torch.no_grad,
        torch.enable_grad,
        torch.inference_mode,
        torch.is_inference_mode_enabled,
        torch.layout,
        torch.align_tensors,
        torch.arange,
        torch.as_strided,
        torch.bartlett_window,
        torch.blackman_window,
        torch.broadcast_shapes,
        torch.can_cast,
        torch.compile,
        torch.cudnn_affine_grid_generator,
        torch.cudnn_batch_norm,
        torch.cudnn_convolution,
        torch.cudnn_convolution_transpose,
        torch.cudnn_convolution_relu,
        torch.cudnn_convolution_add_relu,
        torch.cudnn_grid_sampler,
        torch.cudnn_is_acceptable,
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.export.dynamic_dim,
        torch.export.export,
        torch.export.load,
        torch.export.register_dataclass,
        torch.export.save,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.from_file,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.mkldnn_adaptive_avg_pool2d,
        torch.mkldnn_convolution,
        torch.mkldnn_max_pool2d,
        torch.mkldnn_max_pool3d,
        torch.mkldnn_linear_backward_weights,
        torch.mkldnn_rnn_layer,
        torch.normal,
        torch.ones,
        torch.promote_types,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.result_type,
        torch.scalar_tensor,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.sym_constrain_range,
        torch.sym_constrain_range_for_size,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        torch._jit_internal.boolean_dispatch,
        torch.nn.functional.assert_int_or_pair,
        torch.nn.functional.upsample,
        torch.nn.functional.upsample_bilinear,
        torch.nn.functional.upsample_nearest,
        torch.nn.functional.has_torch_function,
        torch.nn.functional.has_torch_function_unary,
        torch.nn.functional.has_torch_function_variadic,
        torch.nn.functional.handle_torch_function,
        torch.nn.functional.sigmoid,
        torch.nn.functional.hardsigmoid,
        torch.nn.functional.tanh,
        torch.nn.functional._canonical_mask,
        torch.nn.functional._none_or_dtype,
        # Doesn't actually take or return tensor arguments
        torch.nn.init.calculate_gain,
        # These are deprecated; don't test them
        torch.nn.init.uniform,
        torch.nn.init.normal,
        torch.nn.init.constant,
        torch.nn.init.eye,
        torch.nn.init.dirac,
        torch.nn.init.xavier_uniform,
        torch.nn.init.xavier_normal,
        torch.nn.init.kaiming_uniform,
        torch.nn.init.kaiming_normal,
        torch.nn.init.orthogonal,
        torch.nn.init.sparse,
        torch.nested.to_padded_tensor,
        has_torch_function,
        handle_torch_function,
        torch.set_autocast_enabled,
        torch.is_autocast_enabled,
        torch.clear_autocast_cache,
        torch.set_autocast_cpu_enabled,
        torch.is_autocast_cpu_enabled,
        torch.set_autocast_xla_enabled,
        torch.is_autocast_xla_enabled,
        torch.set_autocast_ipu_enabled,
        torch.is_autocast_ipu_enabled,
        torch.set_autocast_cpu_dtype,
        torch.get_autocast_cpu_dtype,
        torch.set_autocast_ipu_dtype,
        torch.get_autocast_ipu_dtype,
        torch.get_autocast_gpu_dtype,
        torch.set_autocast_gpu_dtype,
        torch.get_autocast_xla_dtype,
        torch.set_autocast_xla_dtype,
        torch.autocast_increment_nesting,
        torch.autocast_decrement_nesting,
        torch.is_autocast_cache_enabled,
        torch.set_autocast_cache_enabled,
        torch.nn.functional.hardswish,
        torch.is_vulkan_available,
        torch.are_deterministic_algorithms_enabled,
        torch.use_deterministic_algorithms,
        torch.is_deterministic_algorithms_warn_only_enabled,
        torch.set_deterministic_debug_mode,
        torch.get_deterministic_debug_mode,
        torch.set_float32_matmul_precision,
        torch.get_float32_matmul_precision,
        torch.unify_type_list,
        torch.is_warn_always_enabled,
        torch.set_warn_always,
        torch.vitals_enabled,
        torch.set_vital,
        torch.read_vitals,
        torch.vmap,
        torch.cond,
        torch.frombuffer,
        torch.asarray,
        torch._functional_sym_constrain_range,
        torch._make_dep_token,
        Tensor.__delitem__,
        Tensor.__dir__,
        Tensor.__getattribute__,
        Tensor.__init__,
        Tensor.__iter__,
        Tensor.__init_subclass__,
        Tensor.__delattr__,
        Tensor.__setattr__,
        Tensor.__torch_function__,
        Tensor.__torch_dispatch__,
        Tensor.__new__,
        Tensor.__class__,
        Tensor.__subclasshook__,
        Tensor.__hash__,
        Tensor.as_subclass,
        Tensor.eig,
        Tensor.lstsq,
        Tensor.reinforce,
        Tensor.new,
        Tensor.new_tensor,
        Tensor.new_empty,
        Tensor.new_empty_strided,
        Tensor.new_zeros,
        Tensor.new_ones,
        Tensor.new_full,
        Tensor._make_subclass,
        Tensor.solve,
        Tensor.symeig,
        Tensor.stride,
        Tensor.unflatten,
        Tensor.to_sparse_coo,
        Tensor.to_sparse_csr,
        Tensor.to_sparse_csc,
        Tensor.to_sparse_bsr,
        Tensor.to_sparse_bsc,
        Tensor._to_sparse,
        Tensor._to_sparse_csr,
        Tensor._to_sparse_csc,
        Tensor._to_sparse_bsr,
        Tensor._to_sparse_bsc,
        Tensor._typed_storage,
        Tensor._reduce_ex_internal,
        Tensor._fix_weakref,
        Tensor._view_func,
        Tensor._view_func_unsafe,
        Tensor._rev_view_func_unsafe,
        Tensor._make_wrapper_subclass,
        Tensor._python_dispatch.__get__,
        Tensor._has_symbolic_sizes_strides.__get__,
        Tensor._conj,
        Tensor._conj_physical,
        Tensor._lazy_clone,
        Tensor._neg_view,
        Tensor._is_zerotensor,
        Tensor._is_all_true,
        Tensor._is_any_true,
        Tensor._addmm_activation,
        Tensor.to_padded_tensor,
    }
358

359

360
@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
    """
    Return the accessors whose results the default ``Tensor.__torch_function__``
    must NOT re-wrap in the caller's subclass.

    The subclass-preserving default ``__torch_function__`` normally converts
    results to the subclass. Field accesses, which read a Tensor already stored
    on another Tensor, are different. Callers expect identity to hold across
    repeated accesses (``a.grad is a.grad``), which wrapping on the fly would
    break. The stored value may also already be a subclass instance, in which
    case wrapping it again is wrong. Those accessors are collected here.

    Only accessors that hand back stored state belong on this list. A property
    such as ``Tensor.T`` builds a fresh transposed tensor on every access and
    should still be intercepted. Check a property's implementation before
    adding it. Properties that do not return a Tensor need not be listed,
    although listing them is harmless.

    The returned set is cached by ``functools.lru_cache`` and shared between
    calls; do not mutate it.
    """
    stored_field_descriptors = (
        torch.Tensor._base,
        torch.Tensor.grad,
        torch.Tensor._grad,
    )
    # ``descriptor.__get__`` is the getter invoked on attribute access, which is
    # the callable that reaches ``__torch_function__`` for a property read.
    return {descriptor.__get__ for descriptor in stored_field_descriptors}
384

385

386
@functools.lru_cache(None)
387
@_disable_user_warnings
388
def get_testing_overrides() -> Dict[Callable, Callable]:
389
    """Return a dict containing dummy overrides for all overridable functions
390

391
    Returns
392
    -------
393
    Dict[Callable, Callable]
394
        A dictionary that maps overridable functions in the PyTorch API to
395
        lambda functions that have the same signature as the real function
396
        and unconditionally return -1. These lambda functions are useful
397
        for testing API coverage for a type that defines ``__torch_function__``.
398

399
    Examples
400
    --------
401
    >>> import inspect
402
    >>> my_add = torch.overrides.get_testing_overrides()[torch.add]
403
    >>> inspect.signature(my_add)
404
    <Signature (input, other, out=None)>
405
    """
406
    # Every function in the PyTorchAPI that can be overriden needs an entry
407
    # in this dict.
408
    #
409
    # Optimally we would use inspect to get the function signature and define
410
    # the lambda function procedurally but that is blocked by generating
411
    # function signatures for native kernels that can be consumed by inspect.
412
    # See Issue #28233.
413
    Tensor = torch.Tensor
414
    ret: Dict[Callable, Callable] = {
415
        torch.abs: lambda input, out=None: -1,
416
        torch.absolute: lambda input, out=None: -1,
417
        torch.adaptive_avg_pool1d: lambda input, output_size: -1,
418
        torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
419
        torch.acos: lambda input, out=None: -1,
420
        torch.adjoint: lambda input: -1,
421
        torch.arccos: lambda input, out=None: -1,
422
        torch.acosh: lambda input, out=None: -1,
423
        torch.arccosh: lambda input, out=None: -1,
424
        torch.add: lambda input, other, out=None: -1,
425
        torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
426
        torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
427
        torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
428
        torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
429
        torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
430
        torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
431
        torch.affine_grid_generator: lambda theta, size, align_corners: -1,
432
        torch.all: lambda input, dim=None: -1,
433
        torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1,
434
        torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
435
        torch.amax: lambda input, dim=None: -1,
436
        torch.amin: lambda input, dim=None: -1,
437
        torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
438
        torch.angle: lambda input, out=None: -1,
439
        torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
440
        torch.argmax: lambda input: -1,
441
        torch.argmin: lambda input: -1,
442
        torch.argsort: lambda input, dim=None: -1,
443
        torch.asin: lambda input, out=None: -1,
444
        torch._assert_async: lambda input, msg: -1,
445
        torch.arcsin: lambda input, out=None: -1,
446
        torch.asinh: lambda input, out=None: -1,
447
        torch.arcsinh: lambda input, out=None: -1,
448
        torch.atan: lambda input, out=None: -1,
449
        torch.arctan: lambda input, out=None: -1,
450
        torch.atan2: lambda input, other, out=None: -1,
451
        torch.arctan2: lambda input, other, out=None: -1,
452
        torch.atanh: lambda input, out=None: -1,
453
        torch.arctanh: lambda input, out=None: -1,
454
        torch.atleast_1d: lambda *tensors: -1,
455
        torch.atleast_2d: lambda *tensors: -1,
456
        torch.atleast_3d: lambda *tensors: -1,
457
        torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
458
        torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
459
        torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
460
        torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
461
        torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
462
        torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
463
        torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
464
        torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
465
        torch.batch_norm_stats: lambda input, eps: -1,
466
        torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
467
        torch.bernoulli: lambda input, generator=None, out=None: -1,
468
        torch.bilinear: lambda input1, input2, weight, bias: -1,
469
        torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
470
                                                 reduction='mean', pos_weight=None: -1),
471
        torch.bincount: lambda input, weights=None, minlength=0: -1,
472
        torch.binomial: lambda count, prob, generator=None: -1,
473
        torch.bitwise_and: lambda input, other, out=None: -1,
474
        torch.bitwise_not: lambda input, out=None: -1,
475
        torch.bitwise_or: lambda input, other, out=None: -1,
476
        torch.bitwise_xor: lambda input, other, out=None: -1,
477
        torch.bitwise_left_shift: lambda input, other, out=None: -1,
478
        torch.bitwise_right_shift: lambda input, other, out=None: -1,
479
        torch.block_diag: lambda *tensors: -1,
480
        torch.bmm: lambda input, mat2, out=None: -1,
481
        torch.broadcast_tensors: lambda *tensors: -1,
482
        torch.broadcast_to: lambda self, size: -1,
483
        torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
484
        torch.cartesian_prod: lambda *tensors: -1,
485
        torch.cat: lambda tensors, dim=0, out=None: -1,
486
        torch.concat: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
487
        torch.concatenate: lambda tensors, dim=0, out=None: -1,  # alias for torch.concatenate
488
        torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
489
        torch.ceil: lambda input, out=None: -1,
490
        torch.celu: lambda input, alpha=1., inplace=False: -1,
491
        torch.chain_matmul: lambda *matrices, out=None: -1,
492
        torch.channel_shuffle: lambda input, groups : -1,
493
        torch.cholesky: lambda input, upper=False, out=None: -1,
494
        torch.linalg.cholesky: lambda input, out=None: -1,
495
        torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
496
        torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
497
        torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
498
        torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
499
        torch.chunk: lambda input, chunks, dim=0: -1,
500
        torch.clamp: lambda input, min=None, max=None, out=None: -1,
501
        torch.clip: lambda input, min=None, max=None, out=None: -1,
502
        torch.clamp_min: lambda input, min, out=None: -1,
503
        torch.clamp_max: lambda input, max, out=None: -1,
504
        torch.column_stack: lambda tensors, out=None: -1,
505
        torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
506
        torch.clone: lambda input: -1,
507
        torch.combinations: lambda input, r=2, with_replacement=False: -1,
508
        torch.complex: lambda real, imag: -1,
509
        torch.copysign: lambda input, other, out=None: -1,
510
        torch.polar: lambda abs, ang: -1,
511
        torch.linalg.cond: lambda input, ord=None: -1,
512
        torch.conj: lambda input, out=None: -1,
513
        torch.conj_physical: lambda input, out=None: -1,
514
        torch.resolve_conj: lambda input, out=None: -1,
515
        torch.resolve_neg: lambda input, out=None: -1,
516
        torch.constant_pad_nd: lambda input, pad, value=0: -1,
517
        torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
518
        torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
519
        torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
520
        torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
521
        torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
522
        torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
523
        torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
524
        torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
525
        torch.corrcoef: lambda input: -1,
526
        torch.cos: lambda input, out=None: -1,
527
        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
528
        torch.cosh: lambda input, out=None: -1,
529
        torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
530
        torch.count_nonzero: lambda input: -1,
531
        torch.cross: lambda input, other, dim=None, out=None: -1,
532
        torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
533
        torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
534
                         zero_infinity=False: -1),
535
        torch.cummax: lambda input, dim, out=None: -1,
536
        torch.cummin: lambda input, dim, out=None: -1,
537
        torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
538
        torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
539
        torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
540
        torch.logcumsumexp: lambda input, dim, out=None: -1,
541
        torch.deg2rad: lambda input, out=None: -1,
542
        torch.dequantize: lambda input: -1,
543
        torch.det: lambda input: -1,
544
        torch.linalg.det: lambda input: -1,  # alias for torch.det  # type: ignore[attr-defined]
545
        torch.detach: lambda input: -1,
546
        torch.diag: lambda input, diagonal=0, out=None: -1,
547
        torch.diag_embed: lambda input, diagonal=0, out=None: -1,
548
        torch.diagflat: lambda input, offset=0: -1,
549
        torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
550
        torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
551
        torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
552
        torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
553
        torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
554
        torch.digamma: lambda input, out=None: -1,
555
        torch.dist: lambda input, other, p=2: -1,
556
        torch.div: lambda input, other, rounding_mode=None, out=None: -1,
557
        torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
558
        torch.dot: lambda input, other, out=None: -1,
559
        torch.dropout: lambda input, p, train, inplace=False: -1,
560
        torch.dsmm: lambda input, mat2: -1,
561
        torch.hsmm: lambda mat1, mat2: -1,
562
        torch.dsplit: lambda input, indices_or_sections: -1,
563
        torch.dstack: lambda tensors, out=None: -1,
564
        torch.linalg.eig: lambda input, out=None: -1,
565
        torch.linalg.eigvals: lambda input, out=None: -1,
566
        torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
567
        torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
568
        torch.einsum: lambda equation, *operands: -1,
569
        torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
570
                          sparse=False: -1),
571
        torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
572
                              mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
573
        torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
574
        torch.eq: lambda input, other, out=None: -1,
575
        torch.equal: lambda input, other: -1,
576
        torch.erf: lambda input, out=None: -1,
577
        torch.erfc: lambda input, out=None: -1,
578
        torch.erfinv: lambda input, out=None: -1,
579
        torch.exp: lambda input, out=None: -1,
580
        torch.exp2: lambda input, out=None: -1,
581
        torch.expm1: lambda input, out=None: -1,
582
        torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
583
        torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
584
        torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
585
                                                running_max, scale, zero_point, quant_min, quant_max, ch_axis,
586
                                                per_row_fake_quant=False, symmetric_quant=False: -1),
587
        torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
588
        torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
589
        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
590
        torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
591
                                                          weight_zero_point, bias: -1),
592
        torch.fbgemm_linear_quantize_weight: lambda input: -1,
593
        torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
594
        torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
595
        torch.feature_alpha_dropout: lambda input, p, train: -1,
596
        torch.feature_dropout: lambda input, p, train: -1,
597
        torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
598
        torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
599
        torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
600
        torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
601
        torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
602
        torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
603
        torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
604
        torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
605
        torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
606
        torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
607
        torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
608
        torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
609
        torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
610
        torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
611
        torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
612
        torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
613
        torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
614
        torch.fft.fftshift: lambda input, dim=None: -1,
615
        torch.fft.ifftshift: lambda input, dim=None: -1,
616
        torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
617
        torch.fix: lambda input, out=None: -1,
618
        torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
619
        torch.flip: lambda input, dims: -1,
620
        torch.fliplr: lambda input: -1,
621
        torch.flipud: lambda input: -1,
622
        torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
623
        torch.floor: lambda input, out=None: -1,
624
        torch.floor_divide: lambda input, other: -1,
625
        torch.float_power: lambda input, exponent, out=None: -1,
626
        torch.fmod: lambda input, other, out=None: -1,
627
        torch.frac: lambda input, out=None: -1,
628
        torch.frexp: lambda input, out=None: -1,
629
        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
630
        torch._functional_assert_async: lambda input, msg, dep_token: -1,
631
        torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
632
        torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
633
        torch.gcd: lambda input, other, out=None: -1,
634
        torch.ge: lambda input, other, out=None: -1,
635
        torch.greater_equal: lambda input, other, out=None: -1,
636
        torch.geqrf: lambda input, out=None: -1,
637
        torch.i0: lambda input, out=None: -1,
638
        torch.inner: lambda input, other, out=None: -1,
639
        torch.outer: lambda input, vec2, out=None: -1,
640
        torch.ger: lambda input, vec2, out=None: -1,  # alias for torch.outer
641
        torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
642
        torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
643
        torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
644
        torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
645
        torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
646
        torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
647
        torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
648
        torch.gt: lambda input, other, out=None: -1,
649
        torch.greater: lambda input, other, out=None: -1,
650
        torch.hardshrink: lambda input, lambd=0.5: -1,
651
        torch.heaviside: lambda input, values, out=None: -1,
652
        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
653
        torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
654
        torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
655
        torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
656
        torch.linalg.householder_product: lambda input, tau: -1,
657
        torch.hspmm: lambda mat1, mat2, out=None: -1,
658
        torch.hsplit: lambda input, indices_or_sections: -1,
659
        torch.hstack: lambda tensors, out=None: -1,
660
        torch.hypot: lambda input, other, out=None: -1,
661
        torch.igamma: lambda input, other, out=None: -1,
662
        torch.igammac: lambda input, other, out=None: -1,
663
        torch.imag: lambda input, out=None: -1,
664
        torch.index_add: lambda input, dim, index, source: -1,
665
        torch.index_copy: lambda input, dim, index, source: -1,
666
        torch.index_put: lambda input, indices, values, accumulate=False: -1,
667
        torch.index_select: lambda input, dim, index, out=None: -1,
668
        torch.index_fill: lambda input, dim, index, value: -1,
669
        torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
670
        torch.isfinite: lambda tensor: -1,
671
        torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
672
        torch.isinf: lambda tensor: -1,
673
        torch.isreal: lambda tensor: -1,
674
        torch.isposinf: lambda input, out=None: -1,
675
        torch.isneginf: lambda input, out=None: -1,
676
        torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
677
                              cudnn_enabled: -1),
678
        torch.int_repr: lambda input: -1,
679
        torch.inverse: lambda input, out=None: -1,
680
        torch.linalg.inv: lambda input, out=None: -1,
681
        torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
682
        torch.is_complex: lambda input: -1,
683
        torch.is_conj: lambda input: -1,
684
        torch.is_neg: lambda input: -1,
685
        torch.is_distributed: lambda input: -1,
686
        torch.is_inference: lambda input: -1,
687
        torch.is_floating_point: lambda input: -1,
688
        torch.is_nonzero: lambda input: -1,
689
        torch.is_same_size: lambda input, other: -1,
690
        torch.is_signed: lambda input: -1,
691
        torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
692
        torch.isnan: lambda input: -1,
693
        torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
694
                      normalized=False, onesided=None, length=None, return_complex=False: -1),
695
        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
696
        torch.kron: lambda input, other: -1,
697
        torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
698
        torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
699
        torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
700
        torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
701
        torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
702
        torch.lcm: lambda input, other, out=None: -1,
703
        torch.ldexp: lambda input, other, out=None: -1,
704
        torch.le: lambda input, other, out=None: -1,
705
        torch.less_equal: lambda input, other, out=None: -1,
706
        torch.lerp: lambda input, end, weight, out=None: -1,
707
        torch.lgamma: lambda input, out=None: -1,
708
        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
709
        tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
710
        torch.log: lambda input, out=None: -1,
711
        torch.log_softmax: lambda input, dim, dtype=None: -1,
712
        torch.log10: lambda input, out=None: -1,
713
        torch.log1p: lambda input, out=None: -1,
714
        torch.log2: lambda input, out=None: -1,
715
        torch.logaddexp: lambda input, other, out=None: -1,
716
        torch.logaddexp2: lambda input, other, out=None: -1,
717
        torch.logdet: lambda input: -1,
718
        torch.xlogy: lambda x, y, out=None: -1,
719
        torch.logical_and: lambda input, other, out=None: -1,
720
        torch.logical_not: lambda input, out=None: -1,
721
        torch.logical_or: lambda input, other, out=None: -1,
722
        torch.logical_xor: lambda input, other, out=None: -1,
723
        torch.logit: lambda input, eps=None: -1,
724
        torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
725
        torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
726
        torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
727
        torch.lt: lambda input, other, out=None: -1,
728
        torch.less: lambda input, other, out=None: -1,
729
        torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
730
        torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
731
        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,  # type: ignore[attr-defined]  # noqa: B950
732
        torch.masked_fill: lambda input, mask, value: -1,
733
        torch.masked_scatter: lambda input, mask, source: -1,
734
        torch.masked_select: lambda input, mask, out=None: -1,
735
        torch.matmul: lambda input, other, out=None: -1,
736
        torch.linalg.lu: lambda input, pivot=True, out=None: -1,
737
        torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
738
        torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
739
        torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1,
740
        torch.linalg.matmul: lambda input, other, out=None: -1,  # alias for torch.matmul
741
        torch.matrix_power: lambda input, n: -1,
742
        torch.linalg.matrix_power: lambda input, n, out=None: -1,
743
        torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
744
        torch.linalg.multi_dot: lambda tensors, out=None: -1,
745
        torch.matrix_exp: lambda input: -1,
746
        torch.linalg.matrix_exp: lambda input: -1,
747
        torch.max: lambda input, out=None: -1,
748
        torch.maximum: lambda input, other, out=None: -1,
749
        torch.fmax: lambda input, other, out=None: -1,
750
        torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
751
        torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
752
        torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
753
        torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
754
                                        return_indices=False, ceil_mode=False: -1),
755
        torch.mean: lambda input, dim=None: -1,
756
        torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
757
        torch.median: lambda input, dim=None: -1,
758
        torch.nanmedian: lambda input, dim=None: -1,
759
        torch.meshgrid: lambda *tensors, **kwargs: -1,
760
        torch.min: lambda input, out=None: -1,
761
        torch.minimum: lambda input, other, out=None: -1,
762
        torch.fmin: lambda input, other, out=None: -1,
763
        torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
764
                                  exponential_average_factor, epsilon: -1),
765
        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
766
        torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
767
        torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
768
        torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
769
                                             groups, benchmark, deterministic: -1),
770
        torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
771
                                             deterministic: -1),
772
        torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
773
                           dropout, train, bidirectional, batch_sizes, dropout_state: -1),
774
        torch.mm: lambda input, mat2, out=None: -1,
775
        torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
776
        torch.movedim: lambda input, source, destination: -1,
777
        torch.moveaxis: lambda input, source, destination: -1,
778
        torch.msort: lambda input, descending=False, out=None: -1,
779
        torch.mul: lambda input, other, out=None: -1,
780
        torch.multiply: lambda input, other, out=None: -1,
781
        torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
782
        torch.mv: lambda input, vec, out=None: -1,
783
        torch.mvlgamma: lambda input, p: -1,
784
        torch.narrow: lambda input, dim, start, length: -1,
785
        torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
786
        torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
787
        torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
788
        torch.native_dropout: lambda input, p, train: -1,
789
        torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
790
        torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
791
        torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
792
        torch.native_channel_shuffle: lambda input, groups : -1,
793
        torch.ne: lambda input, other, out=None: -1,
794
        torch.not_equal: lambda input, other, out=None: -1,
795
        torch.neg: lambda input, out=None: -1,
796
        torch.negative: lambda input, out=None: -1,
797
        torch.nextafter: lambda input, other, out=None: -1,
798
        torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
799
        torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
800
        torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
801
        torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
802
        torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
803
        torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
804
        torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
805
        torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
806
        torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
807
        torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
808
        torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
809
                                         count_include_pad=True, divisor_override=None: -1),
810
        torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
811
                                         count_include_pad=True, divisor_override=None: -1),
812
        torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
813
                                         momentum=0.1, eps=1e-05: -1),
814
        torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
815
        torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
816
                                                   reduction="mean": -1),
817
        torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
818
                                                               reduce=None, reduction="mean", pos_weight=None: -1),
819
        torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
820
        torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
821
                                                    reduce=None, reduction='mean': -1),
822
        torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
823
                                            reduce=None, reduction="mean", label_smoothing=0.0: -1),
824
        torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
825
                                       reduction='mean', zero_infinity=False: -1),
826
        torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
827
        torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
828
        torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
829
        torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
830
        torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
831
        torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
832
                                        scale_grad_by_freq=False, sparse=False: -1),
833
        torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
834
                                            scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
835
                                            include_last_offset=False, padding_idx=None: -1),
836
        torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
837
        torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
838
        torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
839
                                                    return_indices=False, _random_samples=None: -1),
840
        torch.nn.functional.fractional_max_pool2d_with_indices: (
841
            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
842
            _random_samples=None: -1),
843
        torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
844
                                                    return_indices=False, _random_samples=None: -1),
845
        torch.nn.functional.fractional_max_pool3d_with_indices: (
846
            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
847
            _random_samples=None: -1),
848
        torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
849
        torch.nn.functional.gelu: lambda input, approximate='none': -1,
850
        torch.nn.functional.glu: lambda input, dim=-1: -1,
851
        torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
852
        torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
853
        torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
854
        torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
855
        torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
856
        torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
857
                                                   reduction='mean': -1),
858
        torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
859
                                            use_input_stats=True, momentum=0.1, eps=1e-05: -1),
860
        torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
861
                                          recompute_scale_factor=None, antialias=False: -1),
862
        torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
863
        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
864
        torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
865
        torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
866
        torch.nn.functional.linear: lambda input, weight, bias=None: -1,
867
        torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
868
        torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
869
        torch.nn.functional.logsigmoid: lambda input: -1,
870
        torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
871
        torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
872
        torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
873
        torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
874
                                                  reduce=None, reduction='mean': -1),
875
        torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
876
                                         ceil_mode=False, return_indices=False: -1),
877
        torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
878
                                                      return_indices=False, ceil_mode=False: -1),
879
        torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
880
                                         ceil_mode=False, return_indices=False: -1),
881
        torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
882
                                                      return_indices=False, ceil_mode=False: -1),
883
        torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
884
                                         return_indices=False, ceil_mode=False: -1),
885
        torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
886
                                                      return_indices=False, ceil_mode=False: -1),
887
        torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
888
        torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
889
        torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
890
        torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
891
        torch.nn.functional.multi_head_attention_forward: (
892
            lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
893
            add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
894
            need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
895
            v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
896
        torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
897
                                                reduce=None, reduction='mean': -1),
898
        torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
899
                                                     reduction='mean': -1),
900
        torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
901
                                                          reduce=None, reduction='mean': -1),
902
        torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
903
                                       reduce=None, reduction='mean': -1),
904
        torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
905
        torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
906
        torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
907
        torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
908
        torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
909
                                               eps=1e-08, reduce=None, reduction='mean': -1),
910
        torch.nn.functional.prelu: lambda input, weight: -1,
911
        torch.nn.functional.relu: lambda input, inplace=False: -1,
912
        torch.nn.functional.relu6: lambda input, inplace=False: -1,
913
        torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
914
        torch.nn.functional.selu: lambda input, inplace=False: -1,
915
        torch.nn.functional.silu: lambda input, inplace=False: -1,
916
        torch.nn.functional.mish: lambda input, inplace=False: -1,
917
        torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
918
        torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
919
        torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
920
        torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
921
        torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
922
        torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
923
        torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
924
        torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
925
        torch.nn.functional.softsign: lambda input: -1,
926
        torch.nn.functional.tanhshrink: lambda input: -1,
927
        torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
928
        torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
929
                                                  swap=False, size_average=None, reduce=None, reduction='mean': -1),
930
        torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
931
                                                                distance_function=None, margin=1.0,
932
                                                                swap=False, reduction='mean': -1),
933
        torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
934
        torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
935
        torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
936
        torch.nn.init.constant_: lambda tensor, val: -1,
937
        torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
938
        torch.nonzero: lambda input, as_tuple=False: -1,
939
        torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
940
        torch.argwhere: lambda input: -1,
941
        torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
942
        torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
943
        torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
944
        torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
945
        torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
946
        torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
947
        torch.numel: lambda input: -1,
948
        torch.orgqr: lambda input, tau: -1,
949
        torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
950
        torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
951
        torch.permute: lambda self, dim: -1,
952
        torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
953
        torch.pdist: lambda input, p=2: -1,
954
        torch.pinverse: lambda input, rcond=1e-15: -1,
955
        torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
956
        torch.pixel_shuffle: lambda input, upscale_factor: -1,
957
        torch.pixel_unshuffle: lambda input, downscale_factor: -1,
958
        torch.poisson: lambda input, generator=None: -1,
959
        torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
960
        torch.polygamma: lambda input, n, out=None: -1,
961
        torch.positive: lambda input, out=None: -1,
962
        torch.prelu: lambda input, weight: -1,
963
        torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
964
        torch.pow: lambda input, exponent, out=None: -1,
965
        torch.prod: lambda input, dtype=None: -1,
966
        torch.put: lambda input, index, source, accumulate=False: -1,
967
        torch.q_per_channel_axis: lambda input: -1,
968
        torch.q_per_channel_scales: lambda input: -1,
969
        torch.q_per_channel_zero_points: lambda input: -1,
970
        torch.q_scale: lambda input: -1,
971
        torch.q_zero_point: lambda input: -1,
972
        torch.qr: lambda input, some=True, out=None: -1,
973
        torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
974
        torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
975
        torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
976
        torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
977
        torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
978
        torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
979
        torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
980
        torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
981
                                   col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
982

983
        torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
984
                                    col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
985
        torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
986
                                     dilation=(1,), ceil_mode=False: -1),
987
        torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
988
                                     dilation=(1, 1), ceil_mode=False: -1),
989
        torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
990
                                     dilation=(1, 1, 1), ceil_mode=False: -1),
991
        torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
992
                                        col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
993
        torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
994
                                        col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
995
        torch.rad2deg: lambda input, out=None: -1,
996
        torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
997
        torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
998
        torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
999
        torch.ravel: lambda input: -1,
1000
        torch.real: lambda input, out=None: -1,
1001
        torch.vdot: lambda input, other, out=None: -1,
1002
        torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
1003
        torch.view_as_real: lambda input: -1,
1004
        torch.view_as_complex: lambda input: -1,
1005
        torch.reciprocal: lambda input, out=None: -1,
1006
        torch.relu: lambda input, inplace=False: -1,
1007
        torch.remainder: lambda input, other, out=None: -1,
1008
        torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
1009
        torch.repeat_interleave: lambda input, dim=None: -1,
1010
        torch.reshape: lambda input, shape: -1,
1011
        torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
1012
        torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
1013
        torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
1014
        torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
1015
        torch.roll: lambda input, shifts, dims=None: -1,
1016
        torch.rot90: lambda input, k=1, dims=(0, 1): -1,
1017
        torch.round: lambda input, out=None: -1,
1018
        torch.row_stack: lambda tensors, out=None: -1,  # alias for torch.vstack
1019
        torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
1020
        torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
1021
        torch.rsqrt: lambda input, out=None: -1,
1022
        torch.rsub: lambda input, other, alpha=1: -1,
1023
        torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
1024
        torch.scatter: lambda input, dim, index, src: -1,
1025
        torch.scatter_add: lambda input, dim, index, src: -1,
1026
        torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
1027
        torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
1028
        torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
1029
        torch.select: lambda input, dim, index: -1,
1030
        torch.select_scatter: lambda input, src, dim, index: -1,
1031
        torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
1032
        torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
1033
        torch.selu: lambda input, inplace=False: -1,
1034
        torch.sigmoid: lambda input, out=None: -1,
1035
        torch.sign: lambda input, out=None: -1,
1036
        torch.signbit: lambda input, out=None: -1,
1037
        torch.sgn: lambda input, out=None: -1,
1038
        torch.sin: lambda input, out=None: -1,
1039
        torch.sinc: lambda input, out=None: -1,
1040
        torch.sinh: lambda input, out=None: -1,
1041
        torch.slogdet: lambda input: -1,
1042
        torch.linalg.slogdet: lambda input: -1,
1043
        torch.smm: lambda input, mat2: -1,
1044
        torch.spmm: lambda input, mat2: -1,
1045
        torch.softmax: lambda input, dim, dtype=None: -1,
1046
        torch.linalg.solve: lambda A, B, left=True, out=None: -1,
1047
        torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1,
1048
        torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
1049
        torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
1050
        torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
1051
        torch.sqrt: lambda input, out=None: -1,
1052
        torch.square: lambda input, out=None: -1,
1053
        torch.squeeze: lambda input, dim=None, out=None: -1,
1054
        torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
1055
        torch.stack: lambda tensors, dim=0, out=None: -1,
1056
        torch.std: lambda input, dim=None: -1,
1057
        torch.std_mean: lambda input, dim=None: -1,
1058
        torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
1059
                     pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
1060
        torch.sub: lambda input, other, out=None: -1,
1061
        torch.subtract: lambda input, other, out=None: -1,
1062
        torch.sum: lambda input, dim=None: -1,
1063
        torch.sym_float: lambda input: -1,
1064
        torch.sym_int: lambda input: -1,
1065
        torch.sym_max: lambda a, b: -1,
1066
        torch.sym_min: lambda a, b: -1,
1067
        torch.sym_not: lambda input: -1,
1068
        torch.sym_ite: lambda a, b, c: -1,
1069
        torch._sym_sqrt: lambda input: -1,
1070
        torch._sym_cos: lambda input: -1,
1071
        torch._sym_cosh: lambda input: -1,
1072
        torch._sym_sin: lambda input: -1,
1073
        torch._sym_sinh: lambda input: -1,
1074
        torch._sym_tan: lambda input: -1,
1075
        torch._sym_tanh: lambda input: -1,
1076
        torch._sym_asin: lambda input: -1,
1077
        torch._sym_acos: lambda input: -1,
1078
        torch._sym_atan: lambda input: -1,
1079
        torch.nansum: lambda input, dim=None: -1,
1080
        torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
1081
        torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
1082
        torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
1083
        torch.linalg.svdvals: lambda input, out=None: -1,
1084
        torch.swapaxes: lambda input, dim0, dim1: -1,
1085
        torch.swapdims: lambda input, axis0, axis1: -1,
1086
        torch.special.airy_ai: lambda input: -1,
1087
        torch.special.bessel_j0: lambda input: -1,
1088
        torch.special.bessel_j1: lambda input: -1,
1089
        torch.special.bessel_y0: lambda input: -1,
1090
        torch.special.bessel_y1: lambda input: -1,
1091
        torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
1092
        torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
1093
        torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1,
1094
        torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1,
1095
        torch.special.digamma: lambda input: -1,
1096
        torch.special.entr: lambda input: -1,
1097
        torch.special.erf: lambda input: -1,
1098
        torch.special.erfc: lambda input: -1,
1099
        torch.special.erfcx: lambda input: -1,
1100
        torch.special.erfinv: lambda input: -1,
1101
        torch.special.exp2: lambda input: -1,
1102
        torch.special.expit: lambda input: -1,
1103
        torch.special.expm1: lambda input: -1,
1104
        torch.special.gammainc: lambda input, other, out=None: -1,
1105
        torch.special.gammaincc: lambda input, other, out=None: -1,
1106
        torch.special.gammaln: lambda input: -1,
1107
        torch.special.hermite_polynomial_h: lambda input, n, out=None: -1,
1108
        torch.special.hermite_polynomial_he: lambda input, n, out=None: -1,
1109
        torch.special.i0: lambda input: -1,
1110
        torch.special.i0e: lambda input: -1,
1111
        torch.special.i1: lambda input: -1,
1112
        torch.special.i1e: lambda input: -1,
1113
        torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1,
1114
        torch.special.legendre_polynomial_p: lambda input, n, out=None: -1,
1115
        torch.special.log1p: lambda input: -1,
1116
        torch.special.log_ndtr: lambda input: -1,
1117
        torch.special.log_softmax: lambda input, dim, dtype=None: -1,
1118
        torch.special.logit: lambda input: -1,
1119
        torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
1120
        torch.special.modified_bessel_i0: lambda input: -1,
1121
        torch.special.modified_bessel_i1: lambda input: -1,
1122
        torch.special.modified_bessel_k0: lambda input: -1,
1123
        torch.special.modified_bessel_k1: lambda input: -1,
1124
        torch.special.multigammaln: lambda input, p: -1,
1125
        torch.special.ndtr: lambda input: -1,
1126
        torch.special.ndtri: lambda input: -1,
1127
        torch.special.polygamma: lambda input, n, out=None: -1,
1128
        torch.special.psi: lambda input: -1,
1129
        torch.special.round: lambda input: -1,
1130
        torch.special.scaled_modified_bessel_k0: lambda input: -1,
1131
        torch.special.scaled_modified_bessel_k1: lambda input: -1,
1132
        torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1,
1133
        torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1,
1134
        torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1,
1135
        torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1,
1136
        torch.special.sinc: lambda input: -1,
1137
        torch.special.softmax: lambda input, dim, dtype=None: -1,
1138
        torch.special.spherical_bessel_j0: lambda input: -1,
1139
        torch.special.xlog1py: lambda input, other, out=None: -1,
1140
        torch.special.xlogy: lambda input, other, out=None: -1,
1141
        torch.special.zeta: lambda self, other, out=None: -1,
1142
        torch.t: lambda input: -1,
1143
        torch.take: lambda input, index: -1,
1144
        torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
1145
        torch.tan: lambda input, out=None: -1,
1146
        torch.tanh: lambda input, out=None: -1,
1147
        torch.linalg.tensorinv: lambda a, ind=2: -1,
1148
        torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
1149
        torch.tensordot: lambda a, b, dims=2, out=None: -1,
1150
        torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
1151
        torch.threshold: lambda input, threshold, value, inplace=False: -1,
1152
        torch.tile: lambda input, dims: -1,
1153
        torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
1154
        torch.trace: lambda input: -1,
1155
        torch.transpose: lambda input, dim0, dim1: -1,
1156
        torch.trapz: lambda y, x=None, dim=-1: -1,
1157
        torch.trapezoid: lambda y, x=None, dim=-1: -1,
1158
        torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
1159
        torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
1160
        torch.tril: lambda input, diagonal=0, out=None: -1,
1161
        torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
1162

1163
                                    size_average=None, reduce=None, reduction='mean': -1),
1164
        torch.triu: lambda input, diagonal=0, out=None: -1,
1165
        torch.true_divide: lambda input, other: -1,
1166
        torch.trunc: lambda input, out=None: -1,
1167
        torch.unbind: lambda input, dim=0: -1,
1168
        torch.unflatten: lambda input, dim, sizes, names: -1,
1169
        torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
1170
        torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
1171
        torch.unravel_index: lambda indices, shape: -1,
1172
        torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
1173
        torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
1174
        torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
1175
        torch.unsqueeze: lambda input, dim, out=None: -1,
1176
        torch.linalg.vander: lambda x, N=None: -1,
1177
        torch.var: lambda input, dim=None: -1,
1178
        torch.var_mean: lambda input, dim=None: -1,
1179
        torch.vsplit: lambda input, indices_or_sections: -1,
1180
        torch.vstack: lambda tensors, out=None: -1,
1181
        torch.where: lambda condition, x=None, y=None: -1,
1182
        torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
1183
        torch._fw_primal_copy: lambda self, level: -1,
1184
        torch._make_dual_copy: lambda primal, tangent, level: -1,
1185
        torch.view_as_real_copy: lambda self: -1,
1186
        torch.view_as_complex_copy: lambda self: -1,
1187
        torch._conj_copy: lambda self: -1,
1188
        torch._neg_view_copy: lambda self: -1,
1189
        torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
1190
        torch._sparse_broadcast_to_copy: lambda self, size: -1,
1191
        torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
1192
        torch.expand_copy: lambda self, size, *, implicit=False: -1,
1193
        torch.narrow_copy: lambda self, dim, start, length: -1,
1194
        torch.permute_copy: lambda self, dims: -1,
1195
        torch._reshape_alias_copy: lambda self, size, stride: -1,
1196
        torch.select_copy: lambda self, dim, index: -1,
1197
        torch.detach_copy: lambda self: -1,
1198
        torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
1199
        torch.split_copy: lambda self, split_size, dim=0: -1,
1200
        torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
1201
        torch.squeeze_copy: lambda self, dim: -1,
1202
        torch.t_copy: lambda self: -1,
1203
        torch.transpose_copy: lambda self, dim0, dim1: -1,
1204
        torch.unsqueeze_copy: lambda self, dim: -1,
1205
        torch._indices_copy: lambda self: -1,
1206
        torch._values_copy: lambda self: -1,
1207
        torch.indices_copy: lambda self: -1,
1208
        torch.values_copy: lambda self: -1,
1209
        torch.crow_indices_copy: lambda self: -1,
1210
        torch.col_indices_copy: lambda self: -1,
1211
        torch.ccol_indices_copy: lambda self: -1,
1212
        torch.row_indices_copy: lambda self: -1,
1213
        torch.unbind_copy: lambda self, dim=0: -1,
1214
        torch.view_copy: lambda self, dtype: -1,
1215
        torch.unfold_copy: lambda self, dimension, size, step: -1,
1216
        torch.alias_copy: lambda self: -1,
1217
        Tensor.__floordiv__: lambda self, other: -1,
1218
        Tensor.__rfloordiv__: lambda self, other: -1,
1219
        Tensor.__ifloordiv__: lambda self, other: -1,
1220
        Tensor.__truediv__: lambda self, other: -1,
1221
        Tensor.__rtruediv__: lambda self, other: -1,
1222
        Tensor.__itruediv__: lambda self, other: -1,
1223
        Tensor.__lshift__: lambda self, other: -1,
1224
        Tensor.__rlshift__: lambda self, other: -1,
1225
        Tensor.__ilshift__: lambda self, other: -1,
1226
        Tensor.__rshift__: lambda self, other: -1,
1227
        Tensor.__rrshift__: lambda self, other: -1,
1228
        Tensor.__irshift__: lambda self, other: -1,
1229
        Tensor.__and__: lambda self, other: -1,
1230
        Tensor.__or__: lambda self, other: -1,
1231
        Tensor.__xor__: lambda self, other: -1,
1232
        Tensor.__float__: lambda self: -1,
1233
        Tensor.__complex__: lambda self: -1,
1234
        Tensor.__array__: lambda self, dtype: -1,
1235
        Tensor.__bool__: lambda self: -1,
1236
        Tensor.__contains__: lambda self, other: -1,
1237
        Tensor.__neg__: lambda self: -1,
1238
        Tensor.__invert__: lambda self: -1,
1239
        Tensor.__mod__: lambda self, other: -1,
1240
        Tensor.__rmod__: lambda self, other: -1,
1241
        Tensor.__imod__: lambda self, other: -1,
1242
        Tensor.__array_wrap__: lambda self, array: -1,
1243
        Tensor.__getitem__: lambda self, idx: -1,
1244
        Tensor.__deepcopy__: lambda self, memo: -1,
1245
        Tensor.__int__: lambda self: -1,
1246
        Tensor.__long__: lambda self: -1,
1247
        Tensor.__index__: lambda self: -1,
1248
        Tensor.__len__: lambda self: -1,
1249
        Tensor.__format__: lambda self, format_spec: -1,
1250
        Tensor.__reduce_ex__: lambda self, proto: -1,
1251
        Tensor.__reversed__: lambda self: -1,
1252
        Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
1253
        Tensor.__setitem__: lambda self, k, v: -1,
1254
        Tensor.__setstate__: lambda self, d: -1,
1255
        Tensor.T.__get__: lambda self: -1,
1256
        Tensor.H.__get__: lambda self: -1,
1257
        Tensor.mT.__get__: lambda self: -1,
1258
        Tensor.mH.__get__: lambda self: -1,
1259
        Tensor._backward_hooks.__get__: lambda self: -1,
1260
        Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
1261
        Tensor._base.__get__: lambda self: -1,
1262
        Tensor._cdata.__get__: lambda self: -1,
1263
        Tensor.grad.__get__: lambda self: -1,
1264
        Tensor._grad.__get__: lambda self: -1,
1265
        Tensor._grad_fn.__get__: lambda self: -1,
1266
        Tensor.grad_fn.__get__: lambda self: -1,
1267
        Tensor._version.__get__: lambda self: -1,
1268
        Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
1269
        Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
1270
        Tensor.data.__get__: lambda self: -1,
1271
        Tensor.device.__get__: lambda self: -1,
1272
        Tensor.dtype.__get__: lambda self: -1,
1273
        Tensor.is_cuda.__get__: lambda self: -1,
1274
        Tensor.is_cpu.__get__: lambda self: -1,
1275
        Tensor.is_xla.__get__: lambda self: -1,
1276
        Tensor.is_xpu.__get__: lambda self: -1,
1277
        Tensor.is_ipu.__get__: lambda self: -1,
1278
        Tensor.is_leaf.__get__: lambda self: -1,
1279
        Tensor.retains_grad.__get__: lambda self: -1,
1280
        Tensor.is_meta.__get__: lambda self: -1,
1281
        Tensor.is_mps.__get__: lambda self: -1,
1282
        Tensor.is_mtia.__get__: lambda self: -1,
1283
        Tensor.is_nested.__get__: lambda self: -1,
1284
        Tensor.is_ort.__get__: lambda self: -1,
1285
        Tensor.is_mkldnn.__get__: lambda self: -1,
1286
        Tensor.is_quantized.__get__: lambda self: -1,
1287
        Tensor.is_sparse.__get__: lambda self: -1,
1288
        Tensor.is_sparse_csr.__get__: lambda self: -1,
1289
        Tensor.is_vulkan.__get__: lambda self: -1,
1290
        Tensor.itemsize.__get__: lambda self: -1,
1291
        Tensor.layout.__get__: lambda self: -1,
1292
        Tensor.name.__get__: lambda self: -1,
1293
        Tensor.names.__get__: lambda self: -1,
1294
        Tensor.nbytes.__get__: lambda self: -1,
1295
        Tensor.ndim.__get__: lambda self: -1,
1296
        Tensor.output_nr.__get__: lambda self: -1,
1297
        Tensor.requires_grad.__get__: lambda self: -1,
1298
        Tensor.shape.__get__: lambda self: -1,
1299
        Tensor.volatile.__get__: lambda self: -1,
1300
        Tensor.real.__get__: lambda self: -1,
1301
        Tensor.imag.__get__: lambda self: -1,
1302
        Tensor.__cuda_array_interface__.__get__: lambda self: -1,
1303
        Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
1304
        Tensor._dimI: lambda self: -1,
1305
        Tensor._dimV: lambda self: -1,
1306
        Tensor._indices: lambda self: -1,
1307
        Tensor._is_view: lambda self: -1,
1308
        Tensor._nnz: lambda self: -1,
1309
        Tensor.crow_indices: lambda self: -1,
1310
        Tensor.col_indices: lambda self: -1,
1311
        Tensor.ccol_indices: lambda self: -1,
1312
        Tensor.row_indices: lambda self: -1,
1313
        Tensor._update_names: lambda self, names, inplace: -1,
1314
        Tensor._values: lambda self: -1,
1315
        Tensor.adjoint: lambda self: -1,
1316
        Tensor.align_as: lambda self, other: -1,
1317
        Tensor.align_to: lambda self, order, ellipsis_idx: -1,
1318
        Tensor.apply_: lambda self, callable: -1,
1319
        Tensor.as_strided: lambda self, size, stride: -1,
1320
        Tensor.as_strided_: lambda self, size, stride: -1,
1321
        Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
1322
        Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
1323
        Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
1324
        Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
1325
        Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
1326
        Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
1327
        Tensor.coalesce: lambda self: -1,
1328
        Tensor._coalesced_: lambda self, coalesced: -1,
1329
        Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
1330
        Tensor.copy_: lambda self, src, non_blocking=False: -1,
1331
        Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
1332
        Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
1333
        Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
1334
        Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
1335
        Tensor.data_ptr: lambda self: -1,
1336
        Tensor.dense_dim: lambda self: -1,
1337
        Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
1338
        Tensor.dim: lambda self: -1,
1339
        Tensor.dim_order: lambda self: -1,
1340
        Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
1341
        Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
1342
        Tensor.element_size: lambda self: -1,
1343
        Tensor.expand: lambda self, size: -1,
1344
        Tensor.expand_as: lambda self, other: -1,
1345
        Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
1346
        Tensor.fill_: lambda self, value: -1,
1347
        Tensor.fill_diagonal_: lambda self, value: -1,
1348
        Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
1349
        Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
1350
        Tensor.geometric_: lambda self, p, *, generator=None: -1,
1351
        Tensor.get_device: lambda self: -1,
1352
        Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
1353
        Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
1354
        Tensor.has_names: lambda self: -1,
1355
        Tensor.indices: lambda self: -1,
1356
        Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
1357
        Tensor.is_coalesced: lambda self: -1,
1358
        Tensor.is_contiguous: lambda self: -1,
1359
        Tensor.is_inference: lambda self: -1,
1360
        Tensor.is_pinned: lambda self: -1,
1361
        Tensor.is_set_to: lambda self, tensor: -1,
1362
        Tensor.is_shared: lambda self: -1,
1363
        Tensor.item: lambda self: -1,
1364
        Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
1365
        Tensor.log_softmax: lambda self, dim: -1,
1366
        Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
1367
        Tensor.map_: lambda self, tensor, callable: -1,
1368
        Tensor.map2_: lambda self, x, y, callable: -1,
1369
        Tensor.mm: lambda self, mat2: -1,
1370
        Tensor.module_load: lambda self, other: -1,
1371
        Tensor.narrow_copy: lambda self, dimension, start, length: -1,
1372
        Tensor.ndimension: lambda self: -1,
1373
        Tensor.nelement: lambda self: -1,
1374
        Tensor._nested_tensor_size: lambda self: -1,
1375
        Tensor._nested_tensor_storage_offsets: lambda self: -1,
1376
        Tensor._nested_tensor_strides: lambda self: -1,
1377
        Tensor.normal_: lambda self: -1,
1378
        Tensor.numpy: lambda self: -1,
1379
        Tensor.permute: lambda self, dim: -1,
1380
        Tensor.pin_memory: lambda self: -1,
1381
        Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
1382
        Tensor.qscheme: lambda self: -1,
1383
        Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
1384
        Tensor.record_stream: lambda self, stream: -1,
1385
        Tensor.refine_names: lambda self, names: -1,
1386
        Tensor.register_hook: lambda self, hook: -1,
1387
        Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
1388
        Tensor.rename: lambda self, name: -1,
1389
        Tensor.repeat: lambda self, *size: -1,
1390
        Tensor.requires_grad_: lambda self, requires_grad=True: -1,
1391
        Tensor.reshape_as: lambda self, other: -1,
1392
        Tensor.resize: lambda self, *size: -1,
1393
        Tensor.resize_: lambda self, size: -1,
1394
        Tensor.resize_as: lambda self, other: -1,
1395
        Tensor.resize_as_sparse_: lambda self, other: -1,
1396
        Tensor.retain_grad: lambda self: -1,
1397
        Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
1398
        Tensor.select_scatter: lambda self, src, dim, index: -1,
1399
        Tensor.share_memory_: lambda self: -1,
1400
        Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
1401
        Tensor.size: lambda self: -1,
1402
        Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
1403
        Tensor.sparse_dim: lambda self: -1,
1404
        Tensor.sparse_mask: lambda self, mask: -1,
1405
        Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1,
1406
        Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
1407
        Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
1408
        Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
1409
        Tensor.storage: lambda self: -1,
1410
        Tensor.untyped_storage: lambda self: -1,
1411
        Tensor.storage_offset: lambda self: -1,
1412
        Tensor.storage_type: lambda self: -1,
1413
        Tensor.sum_to_size: lambda self, size: -1,
1414
        Tensor.tile: lambda self, *reps: -1,
1415
        Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
1416
        Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
1417
        Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
1418
        Tensor.to_sparse: lambda self: -1,
1419
        Tensor.tolist: lambda self: -1,
1420
        Tensor.to_mkldnn: lambda self: -1,
1421
        Tensor.type_as: lambda self, other: -1,
1422
        Tensor.unfold: lambda self, dimension, size, step: -1,
1423
        Tensor.uniform_: lambda self, from_=0, to=1: -1,
1424
        Tensor.values: lambda self: -1,
1425
        Tensor.view: lambda self, shape: -1,
1426
        Tensor.view_as: lambda self, other: -1,
1427
        Tensor.zero_: lambda self: -1,
1428
        Tensor.__dlpack__: lambda self, stream=None: -1,
1429
        Tensor.__dlpack_device__: lambda self: -1,
1430
        torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
1431
    }
1432

1433
    ret2 = {}
1434
    ignored = get_ignored_functions()
1435

1436
    for k, v in ret.items():
1437
        # Generate methods like __add__ and add_ by default from add
1438
        names = [
1439
            k.__name__,  # Default method
1440
            k.__name__ + "_",  # Inplace variant
1441
            "__" + k.__name__ + "__",  # Dunder method
1442
            "__i" + k.__name__ + "__",  # Inplace dunder method
1443
            "__r" + k.__name__ + "__",  # Reverse dunder method
1444
        ]
1445

1446
        if k.__name__.startswith("bitwise_"):
1447
            # bitwise_<op> have dunder methods of the form __<op>__
1448
            # And so on.
1449
            subname = k.__name__[len("bitwise_"):]
1450
            names.extend([
1451
                "__" + subname + "__",
1452
                "__i" + subname + "__",
1453
                "__r" + subname + "__"
1454
            ])
1455

1456
        for name in names:
1457
            func = getattr(Tensor, name, None)
1458
            if callable(func) and func not in ret and func not in ignored:
1459
                ret2[func] = v
1460

1461
    ret.update(ret2)
1462
    return ret
1463

1464
def wrap_torch_function(dispatcher: Callable):
    """Wraps a given function with ``__torch_function__`` -related functionality.

    Parameters
    ----------
    dispatcher : Callable
        A callable that returns an iterable of Tensor-likes passed into the function.
        It is called with exactly the ``*args, **kwargs`` the wrapped function
        receives, so it must accept the same signature as that function.

    Returns
    -------
    Callable
        A decorator. Applying it to ``func`` returns a wrapper (carrying
        ``func``'s metadata via :func:`functools.wraps`) that, on each call,
        asks ``dispatcher`` for the relevant arguments and defers to
        :func:`handle_torch_function` when :func:`has_torch_function` is true
        for them; otherwise it simply calls ``func``.

    Note
    ----
    This decorator may reduce the performance of your code. Generally, it's enough to express
    your code as a series of functions that, themselves, support __torch_function__. If you
    find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
    low-level library and you also need it to work for Tensor-likes, then this function is available.

    Examples
    --------
    >>> def dispatcher(a): # Must have the same signature as func
    ...     return (a,)
    >>> @torch.overrides.wrap_torch_function(dispatcher)
    ... def func(a): # This will make func dispatchable by __torch_function__
    ...     return a + 0
    """
    def inner(func):
        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            relevant_args = dispatcher(*args, **kwargs)
            if has_torch_function(relevant_args):
                # Pass the public wrapper (not the raw ``func``) so that
                # ``__torch_function__`` implementations see the function the
                # user actually called.
                return handle_torch_function(wrapped, relevant_args, *args, **kwargs)

            return func(*args, **kwargs)

        return wrapped

    return inner
1499

1500
def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
    """Collect the arguments whose ``__torch_function__`` should be consulted, in call order.

    Each candidate in ``relevant_args`` is mapped to a type with ``get_type_fn``.
    A candidate is kept only if:

    * its type has not been kept already,
    * the type has a ``__torch_function__`` attribute, and
    * that attribute is not ``torch._C._disabled_torch_function_impl``.

    A kept argument is normally appended. If its type subclasses the type of
    an argument already kept, it is instead inserted immediately before the
    first such argument. The result is "subclasses before superclasses",
    otherwise left-to-right. This is the precedence rule described in
    `NEP-0018`_.

    The C++ counterpart is torch::append_overloaded_arg.

    Parameters
    ----------
    relevant_args : iterable of array-like
        Arguments to inspect for ``__torch_function__``; iterated at most once.

    get_type_fn : callable, optional
        Maps an argument to the type used for the checks above. When omitted,
        the builtin :func:`type` is used.

    Returns
    -------
    overloaded_args : list
        The selected arguments, in the order in which their
        ``__torch_function__`` methods should be tried. Always empty when
        ``torch._C._is_torch_function_enabled()`` is false.

    .. _NEP-0018:
       https://numpy.org/neps/nep-0018-array-function-protocol.html
    """
    type_of = type if get_type_fn is None else get_type_fn

    # With __torch_function__ handling switched off, nothing is overloaded.
    if not torch._C._is_torch_function_enabled():
        return []

    disabled_impl = torch._C._disabled_torch_function_impl
    seen_types: Set[Type] = set()
    ordered: List[Any] = []
    # Only the first argument of each distinct type is kept, so the cost is
    # O(len(relevant_args) * number of distinct overloaded types).
    for candidate in relevant_args:
        candidate_type = type_of(candidate)
        # Types that opt out through _disabled_torch_function_impl must be
        # skipped, see https://github.com/pytorch/pytorch/issues/64687
        is_overloaded = (
            candidate_type not in seen_types
            and hasattr(candidate_type, '__torch_function__')
            and candidate_type.__torch_function__ != disabled_impl
        )
        if not is_overloaded:
            continue
        seen_types.add(candidate_type)
        # Insert before the first kept argument this type subclasses;
        # fall back to appending at the end.
        slot = next(
            (pos for pos, earlier in enumerate(ordered)
             if issubclass(candidate_type, type_of(earlier))),
            len(ordered),
        )
        ordered.insert(slot, candidate)
    return ordered
1570

1571

1572
def handle_torch_function(
        public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
    """Implement a function with checks for ``__torch_function__`` overrides.

    See torch::autograd::handle_torch_function for the equivalent of this
    function in the C++ implementation.

    Dispatch order: if a torch function mode is enabled, the innermost mode is
    popped off the mode stack for the duration of its call and is offered the
    call first. Then each overloaded argument's ``__torch_function__`` is tried
    in the order produced by ``_get_overloaded_args`` (subclasses before
    superclasses). The first result that is not ``NotImplemented`` wins.

    Arguments
    ---------
    public_api : function
        Function exposed by the public torch API originally called like
        ``public_api(*args, **kwargs)`` on which arguments are now being
        checked.
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.
    args : tuple
        Arbitrary positional arguments originally passed into ``public_api``.
    kwargs : dict
        Arbitrary keyword arguments originally passed into ``public_api``.

    Returns
    -------
    object
        Result from calling the active mode's or an overloaded argument's
        ``__torch_function__`` method, as appropriate.

    Raises
    ------
    TypeError : if the active mode (if any) and every overloaded argument
        return ``NotImplemented``.

    Example
    -------
    >>> def func(a):
    ...     if has_torch_function_unary(a):
    ...         return handle_torch_function(func, (a,), a)
    ...     return a + 0
    """
    # Check for __torch_function__ methods.
    overloaded_args = _get_overloaded_args(relevant_args)
    # overloaded_args already have unique types. Deliberately not named
    # ``types`` so the module-level ``types`` import is not shadowed.
    overloaded_types = tuple(map(type, overloaded_args))

    # Check for __torch_function__ mode.
    if _is_torch_function_mode_enabled():
        # if we're here, the mode must be set to a TorchFunctionStackMode
        # this unsets it and calls directly into TorchFunctionStackMode's torch function
        with _pop_mode_temporarily() as mode:
            result = mode.__torch_function__(public_api, overloaded_types, args, kwargs)
        if result is not NotImplemented:
            return result

    # Call overrides
    for overloaded_arg in overloaded_args:
        # This call needs to become a classmethod call in the future.
        # See https://github.com/pytorch/pytorch/issues/63767
        torch_func_method = overloaded_arg.__torch_function__
        # A method bound to the instance itself means __torch_function__ was
        # defined as a plain (instance) method rather than a classmethod.
        if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
                torch_func_method is not torch._C._disabled_torch_function_impl:
            warnings.warn("Defining your `__torch_function__` as a plain method is deprecated and "
                          "will be an error in future, please define it as a classmethod.",
                          DeprecationWarning)

        # Use `public_api` instead of `implementation` so __torch_function__
        # implementations can do equality/identity comparisons.
        result = torch_func_method(public_api, overloaded_types, args, kwargs)

        if result is not NotImplemented:
            return result

    func_name = f'{public_api.__module__}.{public_api.__name__}'
    msg = (
        f"no implementation found for '{func_name}' on types that implement "
        f'__torch_function__: {list(overloaded_types)}'
    )
    if _is_torch_function_mode_enabled():
        msg += f" nor in mode {_get_current_function_mode()}"
    raise TypeError(msg)
1649

1650
# Public alias of the C implementation, with its Python docstring attached.
has_torch_function = _add_docstr(
    _has_torch_function,
    r"""Check for __torch_function__ implementations in the elements of an iterable
    or if a __torch_function__ mode is enabled.  Considers exact ``Tensor`` s
    and ``Parameter`` s non-dispatchable.  Use this to guard a call to
    :func:`handle_torch_function`; don't use it to test if something
    is Tensor-like, use :func:`is_tensor_like` instead.

    Arguments
    ---------
    relevant_args : iterable
        Iterable of arguments to check for __torch_function__ methods.

    Returns
    -------
    bool
        True if any of the elements of relevant_args have __torch_function__
        implementations, False otherwise.

    See Also
    --------
    is_tensor_like
        Checks if something is a Tensor-like, including an exact ``Tensor``.
    """
)
1672

1673
# Single-argument fast path of has_torch_function, with its docstring attached.
has_torch_function_unary = _add_docstr(
    _has_torch_function_unary,
    r"""Special case of ``has_torch_function`` for single inputs.

    Instead of:
      ``has_torch_function((t,))``
    call:
      ``has_torch_function_unary(t)``
    which skips unnecessary packing and unpacking work.
    """
)
1683

1684
# Variadic fast path of has_torch_function, with its docstring attached.
has_torch_function_variadic = _add_docstr(
    _has_torch_function_variadic,
    r"""Special case of ``has_torch_function`` that skips tuple creation.

    This uses the METH_FASTCALL protocol introduced in Python 3.7.

    Instead of:
      ``has_torch_function((a, b))``
    call:
      ``has_torch_function_variadic(a, b)``
    which skips unnecessary packing and unpacking work.
    """
)
1697

1698
@functools.lru_cache(None)
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
    """Scan the tested torch namespaces for ``__torch_function__``-overridable callables.

    Returns
    -------
    Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
        ``(overridable_funcs, index)``:

        * ``overridable_funcs`` maps a namespace object to the overridable
          callables found in it; non-callable descriptors (e.g. properties)
          are instead keyed by the descriptor itself, mapping to its
          ``__get__``.
        * ``index`` maps each discovered callable -- including ignored ones
          and descriptors' ``__get__``/``__set__`` -- to a qualified name
          such as ``"torch.Tensor.add"``.

    The result is cached by ``functools.lru_cache``, so the scan runs once.
    """
    overridable_funcs = collections.defaultdict(list)
    index = {}
    # Fetched once up front instead of on every loop iteration.
    ignored_functions = get_ignored_functions()
    tested_namespaces = [
        ("torch", torch, torch.__all__),
        ("torch.functional", torch.functional, torch.functional.__all__),
        ("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
        ("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
        ("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
        ("torch.linalg", torch.linalg, dir(torch.linalg)),
        ("torch.fft", torch.fft, dir(torch.fft)),
        ("torch.special", torch.special, dir(torch.special)),
    ]
    for namespace_str, namespace, ns_funcs in tested_namespaces:
        for func_name in ns_funcs:
            ignore = False
            if namespace is not torch.Tensor:
                # ignore private functions or functions that are deleted in torch.__init__
                if func_name.startswith('__'):
                    continue
                elif func_name.startswith('_'):
                    ignore = True
                elif func_name.endswith('_'):
                    ignore = True
                elif not func_name[0].islower():
                    ignore = True
                elif func_name == 'unique_dim':
                    continue
                func = getattr(namespace, func_name)
            else:
                func = getattr(namespace, func_name)
                # skip attributes that torch.Tensor shares unchanged with ``object``
                if getattr(object, func_name, None) == func:
                    continue
                if func_name == '__weakref__':
                    continue

            # ignore re-exported modules
            if isinstance(func, types.ModuleType):
                continue
            # ignore __future__ imports
            if isinstance(func, __future__._Feature):
                continue

            # Non-callable descriptors (e.g. properties) are indexed and
            # dispatched through their __get__ method.
            if not callable(func) and hasattr(func, "__get__"):
                index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
                index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
                if ignore:
                    continue
                if func.__get__ in ignored_functions:
                    assert func.__get__ not in get_testing_overrides(), (
                        f"{namespace_str}.{func_name} is returned by "
                        "torch.overrides.get_ignored_functions but still has an explicit override"
                    )
                    continue
                overridable_funcs[func].append(func.__get__)
                continue

            if not callable(func):
                continue

            index[func] = f"{namespace_str}.{func_name}"

            if ignore:
                continue

            # cannot be overridden by __torch_function__
            if func in ignored_functions:
                assert func not in get_testing_overrides(), (
                    f"{namespace_str}.{func_name} is returned by "
                    "torch.overrides.get_ignored_functions but still has an explicit override"
                )
                continue
            overridable_funcs[namespace].append(func)
    return overridable_funcs, index
1773

1774
@_disable_user_warnings
def get_overridable_functions() -> Dict[Any, List[Callable]]:
    """Return the functions that ``__torch_function__`` can override.

    Returns
    -------
    Dict[Any, List[Callable]]
        Mapping from each namespace that holds overridable functions to the
        list of functions in that namespace which can be overridden.
    """
    overridable, _qualified_names = _get_overridable_functions()
    return overridable
1785

1786
@_disable_user_warnings
def resolve_name(f):
    """Return a human-readable name for a function handed to ``__torch_function__``.

    Arguments
    ---------
    f : Callable
        The function whose name should be resolved.

    Returns
    -------
    str
        Name of the function; if eval'ed it should give back the input
        function. Operator overloads are rendered with ``str(f)``; any other
        function missing from the overridable-function index yields ``None``.
    """
    op_types = (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
    if not isinstance(f, op_types):
        _overridable, qualified_names = _get_overridable_functions()
        return qualified_names.get(f)
    return str(f)
1805

1806
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
    """Return (and cache) the set of overridable methods on ``torch.Tensor``."""
    return set(get_overridable_functions()[torch.Tensor])
1812

1813
@_disable_user_warnings
def is_tensor_method_or_property(func: Callable) -> bool:
    """
    Returns True if the function passed in is a handler for a
    method or property belonging to ``torch.Tensor``, as passed
    into ``__torch_function__``.

    .. note::
       For properties, their ``__get__`` method must be passed in.
       Any callable whose ``__name__`` is ``"__get__"`` is treated as
       a property getter.

    This may be needed, in particular, for the following reasons:

    1. Methods/properties sometimes don't contain a `__module__` slot.
    2. They require that the first passed-in argument is an instance
       of ``torch.Tensor``.

    Callables without a ``__name__`` attribute (e.g. :class:`functools.partial`
    objects) that are not ``torch.Tensor`` methods return ``False`` rather
    than raising :class:`AttributeError`.

    Examples
    --------
    >>> is_tensor_method_or_property(torch.Tensor.add)
    True
    >>> is_tensor_method_or_property(torch.add)
    False
    """
    if func in _get_tensor_methods():
        return True
    # Property getters arrive as ``<descriptor>.__get__`` and are recognised
    # by name; getattr avoids AttributeError for callables lacking __name__.
    return getattr(func, "__name__", None) == "__get__"
1837

1838
def is_tensor_like(inp):
    """
    Returns ``True`` if the passed-in input is a Tensor-like.

    An exact ``torch.Tensor`` always counts; otherwise the input is
    Tensor-like whenever it exposes a ``__torch_function__`` attribute.

    Examples
    --------
    A subclass of tensor is generally a Tensor-like.

    >>> class SubTensor(torch.Tensor): ...
    >>> is_tensor_like(SubTensor([0]))
    True

    Built-in or user types aren't usually Tensor-like.

    >>> is_tensor_like(6)
    False
    >>> is_tensor_like(None)
    False
    >>> class NotATensor: ...
    >>> is_tensor_like(NotATensor())
    False

    But, they can be made Tensor-like by implementing __torch_function__.

    >>> class TensorLike:
    ...     @classmethod
    ...     def __torch_function__(cls, func, types, args, kwargs):
    ...         return -1
    >>> is_tensor_like(TensorLike())
    True
    """
    if type(inp) is torch.Tensor:
        return True
    return hasattr(inp, "__torch_function__")
1873

1874
class TorchFunctionMode:
    """
    A ``TorchFunctionMode`` allows you to override the meaning of all
    ``__torch_function__`` overridable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchFunctionMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_function__`` implementation, by default, they will forward on to
    the next mode on the mode stack (``handle_torch_function`` pops the active
    mode while calling it).  If you want to recursively call back into your
    current ``__torch_function__`` implementation, either explicitly invoke
    ``self.__torch_function__(...)``, or re-enter the mode with ``with self:``
    to push it back onto the stack and make the PyTorch API self-referential
    (beware of infinite loops, in this case!)
    """
    # Type annotation only; this class never assigns it.
    inner: "TorchFunctionMode"

    # Force metaclass to generate constructor at the base of the hierarchy
    def __init__(self):
        pass

    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Handle a call to ``func``; subclasses must override this."""
        raise NotImplementedError()

    def __enter__(self):
        """Push this mode onto the torch function mode stack and return it."""
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop the innermost mode off the stack; exceptions are not suppressed."""
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        """Deprecated: build a mode instance without pushing it.

        Emits a warning and returns ``cls(*args, **kwargs)``; use
        ``with Mode(...):`` instead.
        """
        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
        instance = cls(*args, **kwargs)
        return instance
1925

1926

1927
def _get_current_function_mode():
1928
    stack_len = _len_torch_function_stack()
1929
    return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None
1930

1931

1932
def _get_current_function_mode_stack():
1933
    stack_len = _len_torch_function_stack()
1934
    return [_get_function_stack_at(i) for i in range(stack_len)]
1935

1936
def _push_mode(mode):
1937
    _push_on_torch_function_stack(mode)
1938

1939

1940
def _pop_mode():
1941
    old = _pop_torch_function_stack()
1942
    return old
1943

1944

1945
@contextlib.contextmanager
def _pop_mode_temporarily():
    """Pop the innermost torch function mode for the duration of a ``with`` block.

    Yields the popped mode and pushes it back onto the stack when the block
    exits, whether normally or by exception.
    """
    saved_mode = _pop_mode()
    try:
        yield saved_mode
    finally:
        # Restore the stack to exactly its state before entry.
        _push_mode(saved_mode)
1952

1953
class BaseTorchFunctionMode(TorchFunctionMode):
    """Pass-through mode: every intercepted call simply invokes ``func`` unchanged."""

    def __torch_function__(self, func, types, args=(), kwargs=None):
        call_kwargs = kwargs if kwargs is not None else {}
        return func(*args, **call_kwargs)
1958

1959

1960
@contextlib.contextmanager
def enable_reentrant_dispatch():
    """Context manager that runs its body under ``torch._C._RestorePythonTLSSnapshot``.

    NOTE(review): presumably intended to be entered from inside a
    ``__torch_dispatch__`` handler so re-entrant dispatch sees the Python TLS
    state captured on entry -- confirm against the C++ guard.
    """
    # NB: this can't simply be
    # `enable_reentrant_dispatch = torch._C._RestorePythonTLSSnapshot`
    # because:
    # 1. torch._C._RestorePythonTLSSnapshot is unavailable when this file
    #    initially gets imported. Probably an import order thing.
    # 2. enable_reentrant_dispatch is technically public API; assigning
    #    it the object would change the __module__ to look private.
    with torch._C._RestorePythonTLSSnapshot():
        # No try/finally needed: the enclosing ``with`` already runs the
        # guard's exit whether the body completes or raises.
        yield
1974

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.