"""
Python implementation of ``__torch_function__``

While most of the torch API and handling for ``__torch_function__`` happens
at the C++ level, some of the torch API is written in Python, so we need
Python-level handling for ``__torch_function__`` overrides as well. The main
developer-facing functionality in this file is ``handle_torch_function`` and
``has_torch_function``. See ``torch/functional.py`` and ``test/test_overrides.py``
for usage examples.

This design is heavily inspired by NumPy's ``__array_function__`` (see
https://github.com/pytorch/pytorch/issues/24015 and
https://www.numpy.org/neps/nep-0018-array-function-protocol.html).

If changing this file in a way that can affect ``__torch_function__`` overhead,
please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
instructions in the ``README.md`` in that directory.
"""
29
from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
30
from functools import wraps
35
_has_torch_function, _has_torch_function_unary,
36
_has_torch_function_variadic, _add_docstr,
37
_push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
38
_is_torch_function_mode_enabled)
# NOTE(review): damaged fragment of this module's public-name list (presumably
# ``__all__ = [...]``). The opening and closing lines are missing, and every
# other line is a stray original line number. The numbers also jump (44 -> 48),
# so some entries are missing as well. Restore this from version control rather
# than editing it by hand.
41
"get_ignored_functions",
42
"get_overridable_functions",
43
"get_testing_overrides",
44
"handle_torch_function",
48
"is_tensor_method_or_property",
49
"wrap_torch_function",
50
"enable_reentrant_dispatch",
def _disable_user_warnings(
        func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
    """
    Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
    given ``regex`` pattern.

    Arguments
    ---------
    func : function
        Function to disable the warnings for.
    regex : str
        A regex pattern compilable by ``re.compile``. This is used to match the ``UserWarning`` message.
    module : str
        The python module to which the filtering should be restricted. Like ``regex``, this is a
        regular expression matched against the start of the warning's module name.

    Returns
    -------
    function
        A wrapper with the same name and docstring as ``func`` that calls ``func`` with the
        matching warnings suppressed and returns its result.
    """

    @wraps(func)  # keep ``func``'s __name__/__doc__ so caching decorators and docs still see it
    def wrapper(*args, **kwargs):
        # ``catch_warnings`` snapshots the global filter list and restores it on exit,
        # so the "ignore" filter below is only active for the duration of this call.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, message=regex, module=module)
            return func(*args, **kwargs)

    return wrapper
# NOTE(review): this definition is damaged in this copy of the file and is not
# valid Python as written:
#   * every other line is a stray original line number (``83``, ``84``, ...);
#   * the docstring's triple quotes are missing;
#   * the statement that should collect the entries below is missing
#     (presumably ``return {`` ... ``}``, given the ``Set[Callable]`` annotation);
#   * many entries are missing too: the embedded numbers jump, e.g. 129 -> 132.
# Restore it from version control instead of patching it by hand.
# Points to confirm once it is restored:
#   * the return annotation is ``Set[Callable]``, but the docstring calls the
#     result "A tuple of functions";
#   * the doctest outputs are missing, and ``Tensor.as_subclass`` (used in the
#     first doctest) is not among the entries visible below.
83
@functools.lru_cache(None)
84
@_disable_user_warnings
85
def get_ignored_functions() -> Set[Callable]:
87
Return public functions that cannot be overridden by ``__torch_function__``.
92
A tuple of functions that are publicly available in the torch API but cannot
93
be overridden with ``__torch_function__``. Mostly this is because none of the
94
arguments of these functions are tensors or tensor-likes.
98
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
100
>>> torch.add in torch.overrides.get_ignored_functions()
103
# Local alias that keeps the ``Tensor.*`` entries below short.
Tensor = torch.Tensor
108
torch.set_default_tensor_type,
109
torch.set_default_device,
110
torch.get_default_device,
118
torch.set_printoptions,
120
torch.get_default_dtype,
121
torch.get_num_interop_threads,
122
torch.get_num_threads,
123
torch.init_num_threads,
124
torch.import_ir_module,
125
torch.import_ir_module_from_buffer,
126
torch.is_anomaly_enabled,
127
torch.is_anomaly_check_nan_enabled,
128
torch.is_grad_enabled,
129
torch.merge_type_from_type_comment,
132
torch.parse_type_comment,
133
torch.set_anomaly_enabled,
134
torch.set_flush_denormal,
135
torch.set_num_interop_threads,
136
torch.set_num_threads,
142
torch.default_generator,
156
torch.set_grad_enabled,
159
torch.inference_mode,
160
torch.is_inference_mode_enabled,
165
torch.bartlett_window,
166
torch.blackman_window,
167
torch.broadcast_shapes,
170
torch.cudnn_affine_grid_generator,
171
torch.cudnn_batch_norm,
172
torch.cudnn_convolution,
173
torch.cudnn_convolution_transpose,
174
torch.cudnn_convolution_relu,
175
torch.cudnn_convolution_add_relu,
176
torch.cudnn_grid_sampler,
177
torch.cudnn_is_acceptable,
179
torch.empty_permuted,
181
torch.empty_quantized,
182
torch.export.dynamic_dim,
185
torch.export.register_dataclass,
193
torch.hamming_window,
198
torch.mkldnn_adaptive_avg_pool2d,
199
torch.mkldnn_convolution,
200
torch.mkldnn_max_pool2d,
201
torch.mkldnn_max_pool3d,
202
torch.mkldnn_linear_backward_weights,
203
torch.mkldnn_rnn_layer,
214
torch.sparse_coo_tensor,
215
torch.sparse_compressed_tensor,
216
torch.sparse_csr_tensor,
217
torch.sparse_csc_tensor,
218
torch.sparse_bsr_tensor,
219
torch.sparse_bsc_tensor,
220
torch.sym_constrain_range,
221
torch.sym_constrain_range_for_size,
226
torch._jit_internal.boolean_dispatch,
227
torch.nn.functional.assert_int_or_pair,
228
torch.nn.functional.upsample,
229
torch.nn.functional.upsample_bilinear,
230
torch.nn.functional.upsample_nearest,
231
torch.nn.functional.has_torch_function,
232
torch.nn.functional.has_torch_function_unary,
233
torch.nn.functional.has_torch_function_variadic,
234
torch.nn.functional.handle_torch_function,
235
torch.nn.functional.sigmoid,
236
torch.nn.functional.hardsigmoid,
237
torch.nn.functional.tanh,
238
torch.nn.functional._canonical_mask,
239
torch.nn.functional._none_or_dtype,
241
torch.nn.init.calculate_gain,
243
torch.nn.init.uniform,
244
torch.nn.init.normal,
245
torch.nn.init.constant,
248
torch.nn.init.xavier_uniform,
249
torch.nn.init.xavier_normal,
250
torch.nn.init.kaiming_uniform,
251
torch.nn.init.kaiming_normal,
252
torch.nn.init.orthogonal,
253
torch.nn.init.sparse,
254
torch.nested.to_padded_tensor,
# Unqualified name: this module's own ``handle_torch_function`` (it also appears
# in the module's public-name list). Presumably it is defined elsewhere in the
# file; confirm.
256
handle_torch_function,
257
torch.set_autocast_enabled,
258
torch.is_autocast_enabled,
259
torch.clear_autocast_cache,
260
torch.set_autocast_cpu_enabled,
261
torch.is_autocast_cpu_enabled,
262
torch.set_autocast_xla_enabled,
263
torch.is_autocast_xla_enabled,
264
torch.set_autocast_ipu_enabled,
265
torch.is_autocast_ipu_enabled,
266
torch.set_autocast_cpu_dtype,
267
torch.get_autocast_cpu_dtype,
268
torch.set_autocast_ipu_dtype,
269
torch.get_autocast_ipu_dtype,
270
torch.get_autocast_gpu_dtype,
271
torch.set_autocast_gpu_dtype,
272
torch.get_autocast_xla_dtype,
273
torch.set_autocast_xla_dtype,
274
torch.autocast_increment_nesting,
275
torch.autocast_decrement_nesting,
276
torch.is_autocast_cache_enabled,
277
torch.set_autocast_cache_enabled,
278
torch.nn.functional.hardswish,
279
torch.is_vulkan_available,
280
torch.are_deterministic_algorithms_enabled,
281
torch.use_deterministic_algorithms,
282
torch.is_deterministic_algorithms_warn_only_enabled,
283
torch.set_deterministic_debug_mode,
284
torch.get_deterministic_debug_mode,
285
torch.set_float32_matmul_precision,
286
torch.get_float32_matmul_precision,
287
torch.unify_type_list,
288
torch.is_warn_always_enabled,
289
torch.set_warn_always,
290
torch.vitals_enabled,
297
torch._functional_sym_constrain_range,
298
torch._make_dep_token,
# Tensor methods and descriptors, via the local ``Tensor`` alias.
301
Tensor.__getattribute__,
304
Tensor.__init_subclass__,
307
Tensor.__torch_function__,
308
Tensor.__torch_dispatch__,
311
Tensor.__subclasshook__,
320
Tensor.new_empty_strided,
324
Tensor._make_subclass,
329
Tensor.to_sparse_coo,
330
Tensor.to_sparse_csr,
331
Tensor.to_sparse_csc,
332
Tensor.to_sparse_bsr,
333
Tensor.to_sparse_bsc,
335
Tensor._to_sparse_csr,
336
Tensor._to_sparse_csc,
337
Tensor._to_sparse_bsr,
338
Tensor._to_sparse_bsc,
339
Tensor._typed_storage,
340
Tensor._reduce_ex_internal,
343
Tensor._view_func_unsafe,
344
Tensor._rev_view_func_unsafe,
345
Tensor._make_wrapper_subclass,
# Property getters are listed through their descriptor's ``__get__``.
346
Tensor._python_dispatch.__get__,
347
Tensor._has_symbolic_sizes_strides.__get__,
349
Tensor._conj_physical,
352
Tensor._is_zerotensor,
355
Tensor._addmm_activation,
356
Tensor.to_padded_tensor,
@functools.lru_cache(None)
def get_default_nowrap_functions() -> Set[Callable]:
    """
    Return public functions that do not wrap in a subclass when invoked by
    the default ``Tensor.__torch_function__`` that preserves subclasses. Typically,
    these functions represent field accesses (i.e., retrieving a Tensor that
    is stored somewhere on the Tensor) as opposed to computation. Users of
    these functions expect object identity to be preserved over multiple accesses
    (e.g., ``a.grad is a.grad``), which cannot be upheld if we're wrapping on
    the fly every time (furthermore, the tensor stored here might already be
    the subclass, in which case wrapping really ought not to happen).

    Not ALL property accessors have this property; for example ``Tensor.T`` actually
    just creates a new transposed tensor on the fly, and so we SHOULD interpose on
    these calls (you need to check the implementation of the function to see if
    this is the case or not). Additionally, if a property accessor doesn't return a Tensor,
    it doesn't have to be on this list (though it is harmless if it is).

    Returns
    -------
    Set[Callable]
        The ``__get__`` accessors of the ``torch.Tensor`` properties listed below.
        The result is cached, so every call returns the same set object.
    """
    Tensor = torch.Tensor
    return {
        Tensor._base.__get__,
        # The public ``grad`` getter is needed for the ``a.grad is a.grad``
        # guarantee described above; the private ``_grad`` getter alone is not enough.
        Tensor.grad.__get__,
        Tensor._grad.__get__,
    }
386
@functools.lru_cache(None)
387
@_disable_user_warnings
388
def get_testing_overrides() -> Dict[Callable, Callable]:
389
"""Return a dict containing dummy overrides for all overridable functions
393
Dict[Callable, Callable]
394
A dictionary that maps overridable functions in the PyTorch API to
395
lambda functions that have the same signature as the real function
396
and unconditionally return -1. These lambda functions are useful
397
for testing API coverage for a type that defines ``__torch_function__``.
402
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
403
>>> inspect.signature(my_add)
404
<Signature (input, other, out=None)>
413
Tensor = torch.Tensor
414
ret: Dict[Callable, Callable] = {
415
torch.abs: lambda input, out=None: -1,
416
torch.absolute: lambda input, out=None: -1,
417
torch.adaptive_avg_pool1d: lambda input, output_size: -1,
418
torch.adaptive_max_pool1d: lambda inputs, output_size: -1,
419
torch.acos: lambda input, out=None: -1,
420
torch.adjoint: lambda input: -1,
421
torch.arccos: lambda input, out=None: -1,
422
torch.acosh: lambda input, out=None: -1,
423
torch.arccosh: lambda input, out=None: -1,
424
torch.add: lambda input, other, out=None: -1,
425
torch.addbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
426
torch.addcdiv: lambda input, tensor1, tensor2, value=1, out=None: -1,
427
torch.addcmul: lambda input, tensor1, tensor2, value=1, out=None: -1,
428
torch.addmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
429
torch.addmv: lambda input, mat, vec, beta=1, alpha=1, out=None: -1,
430
torch.addr: lambda input, vec1, vec2, beta=1, alpha=1, out=None: -1,
431
torch.affine_grid_generator: lambda theta, size, align_corners: -1,
432
torch.all: lambda input, dim=None: -1,
433
torch.allclose: lambda input, other, trol=1e-05, atol=1e-08, equal_nan=False: -1,
434
torch.alpha_dropout: lambda input, p, train, inplace=False: -1,
435
torch.amax: lambda input, dim=None: -1,
436
torch.amin: lambda input, dim=None: -1,
437
torch.aminmax: lambda input, dim=None, keepdim=False, out=None: -1,
438
torch.angle: lambda input, out=None: -1,
439
torch.any: lambda input, dim=None, keepdim=False, out=None: -1,
440
torch.argmax: lambda input: -1,
441
torch.argmin: lambda input: -1,
442
torch.argsort: lambda input, dim=None: -1,
443
torch.asin: lambda input, out=None: -1,
444
torch._assert_async: lambda input, msg: -1,
445
torch.arcsin: lambda input, out=None: -1,
446
torch.asinh: lambda input, out=None: -1,
447
torch.arcsinh: lambda input, out=None: -1,
448
torch.atan: lambda input, out=None: -1,
449
torch.arctan: lambda input, out=None: -1,
450
torch.atan2: lambda input, other, out=None: -1,
451
torch.arctan2: lambda input, other, out=None: -1,
452
torch.atanh: lambda input, out=None: -1,
453
torch.arctanh: lambda input, out=None: -1,
454
torch.atleast_1d: lambda *tensors: -1,
455
torch.atleast_2d: lambda *tensors: -1,
456
torch.atleast_3d: lambda *tensors: -1,
457
torch.avg_pool1d: lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True: -1,
458
torch.baddbmm: lambda input, batch1, batch2, alpha=1, beta=1, out=None: -1,
459
torch.batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled: -1,
460
torch.batch_norm_backward_elemt: lambda grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count_tensor: -1,
461
torch.batch_norm_backward_reduce: lambda grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g: -1,
462
torch.batch_norm_elemt: lambda input, weight, bias, mean, invstd, eps: -1,
463
torch.batch_norm_gather_stats: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
464
torch.batch_norm_gather_stats_with_counts: lambda input, mean, invstd, running_mean, running_var, momentum, eps, count: -1,
465
torch.batch_norm_stats: lambda input, eps: -1,
466
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
467
torch.bernoulli: lambda input, generator=None, out=None: -1,
468
torch.bilinear: lambda input1, input2, weight, bias: -1,
469
torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
470
reduction='mean', pos_weight=None: -1),
471
torch.bincount: lambda input, weights=None, minlength=0: -1,
472
torch.binomial: lambda count, prob, generator=None: -1,
473
torch.bitwise_and: lambda input, other, out=None: -1,
474
torch.bitwise_not: lambda input, out=None: -1,
475
torch.bitwise_or: lambda input, other, out=None: -1,
476
torch.bitwise_xor: lambda input, other, out=None: -1,
477
torch.bitwise_left_shift: lambda input, other, out=None: -1,
478
torch.bitwise_right_shift: lambda input, other, out=None: -1,
479
torch.block_diag: lambda *tensors: -1,
480
torch.bmm: lambda input, mat2, out=None: -1,
481
torch.broadcast_tensors: lambda *tensors: -1,
482
torch.broadcast_to: lambda self, size: -1,
483
torch.bucketize: lambda input, boundaries, out_int32=False, right=False, out=None: -1,
484
torch.cartesian_prod: lambda *tensors: -1,
485
torch.cat: lambda tensors, dim=0, out=None: -1,
486
torch.concat: lambda tensors, dim=0, out=None: -1,
487
torch.concatenate: lambda tensors, dim=0, out=None: -1,
488
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
489
torch.ceil: lambda input, out=None: -1,
490
torch.celu: lambda input, alpha=1., inplace=False: -1,
491
torch.chain_matmul: lambda *matrices, out=None: -1,
492
torch.channel_shuffle: lambda input, groups : -1,
493
torch.cholesky: lambda input, upper=False, out=None: -1,
494
torch.linalg.cholesky: lambda input, out=None: -1,
495
torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
496
torch.cholesky_inverse: lambda input, upper=False, out=None: -1,
497
torch.cholesky_solve: lambda input1, input2, upper=False, out=None: -1,
498
torch.choose_qparams_optimized: lambda input, numel, n_bins, ratio, bit_width: -1,
499
torch.chunk: lambda input, chunks, dim=0: -1,
500
torch.clamp: lambda input, min=None, max=None, out=None: -1,
501
torch.clip: lambda input, min=None, max=None, out=None: -1,
502
torch.clamp_min: lambda input, min, out=None: -1,
503
torch.clamp_max: lambda input, max, out=None: -1,
504
torch.column_stack: lambda tensors, out=None: -1,
505
torch.cov: lambda input, correction=1, fweights=None, aweights=None: -1,
506
torch.clone: lambda input: -1,
507
torch.combinations: lambda input, r=2, with_replacement=False: -1,
508
torch.complex: lambda real, imag: -1,
509
torch.copysign: lambda input, other, out=None: -1,
510
torch.polar: lambda abs, ang: -1,
511
torch.linalg.cond: lambda input, ord=None: -1,
512
torch.conj: lambda input, out=None: -1,
513
torch.conj_physical: lambda input, out=None: -1,
514
torch.resolve_conj: lambda input, out=None: -1,
515
torch.resolve_neg: lambda input, out=None: -1,
516
torch.constant_pad_nd: lambda input, pad, value=0: -1,
517
torch.conv1d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
518
torch.conv2d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
519
torch.conv3d: lambda input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1: -1,
520
torch.convolution: lambda input, weight, bias, stride, padding, dilation, transposed, output_adding, groups: -1,
521
torch.conv_tbc: lambda input, weight, bias, pad=0: -1,
522
torch.conv_transpose1d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
523
torch.conv_transpose2d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
524
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
525
torch.corrcoef: lambda input: -1,
526
torch.cos: lambda input, out=None: -1,
527
torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
528
torch.cosh: lambda input, out=None: -1,
529
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
530
torch.count_nonzero: lambda input: -1,
531
torch.cross: lambda input, other, dim=None, out=None: -1,
532
torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
533
torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
534
zero_infinity=False: -1),
535
torch.cummax: lambda input, dim, out=None: -1,
536
torch.cummin: lambda input, dim, out=None: -1,
537
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
538
torch.cumsum: lambda input, dim, out=None, dtype=None: -1,
539
torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1,
540
torch.logcumsumexp: lambda input, dim, out=None: -1,
541
torch.deg2rad: lambda input, out=None: -1,
542
torch.dequantize: lambda input: -1,
543
torch.det: lambda input: -1,
544
torch.linalg.det: lambda input: -1,
545
torch.detach: lambda input: -1,
546
torch.diag: lambda input, diagonal=0, out=None: -1,
547
torch.diag_embed: lambda input, diagonal=0, out=None: -1,
548
torch.diagflat: lambda input, offset=0: -1,
549
torch.diff: lambda input, n=1, dim=-1, prepend=None, append=None, out=None: -1,
550
torch.diagonal: lambda input, offset=0, dim1=0, dim2=1: -1,
551
torch.linalg.diagonal: lambda input, offset=0, dim1=-2, dim2=-1: -1,
552
torch.diagonal_scatter: lambda input, src, offset=0, dim1=0, dim2=1: -1,
553
torch.as_strided_scatter: lambda self, src, size, stride, storage_offset=None: -1,
554
torch.digamma: lambda input, out=None: -1,
555
torch.dist: lambda input, other, p=2: -1,
556
torch.div: lambda input, other, rounding_mode=None, out=None: -1,
557
torch.divide: lambda input, other, rounding_mode=None, out=None: -1,
558
torch.dot: lambda input, other, out=None: -1,
559
torch.dropout: lambda input, p, train, inplace=False: -1,
560
torch.dsmm: lambda input, mat2: -1,
561
torch.hsmm: lambda mat1, mat2: -1,
562
torch.dsplit: lambda input, indices_or_sections: -1,
563
torch.dstack: lambda tensors, out=None: -1,
564
torch.linalg.eig: lambda input, out=None: -1,
565
torch.linalg.eigvals: lambda input, out=None: -1,
566
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
567
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
568
torch.einsum: lambda equation, *operands: -1,
569
torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
571
torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
572
mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
573
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
574
torch.eq: lambda input, other, out=None: -1,
575
torch.equal: lambda input, other: -1,
576
torch.erf: lambda input, out=None: -1,
577
torch.erfc: lambda input, out=None: -1,
578
torch.erfinv: lambda input, out=None: -1,
579
torch.exp: lambda input, out=None: -1,
580
torch.exp2: lambda input, out=None: -1,
581
torch.expm1: lambda input, out=None: -1,
582
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
583
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
584
torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
585
running_max, scale, zero_point, quant_min, quant_max, ch_axis,
586
per_row_fake_quant=False, symmetric_quant=False: -1),
587
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
588
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
589
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
590
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
591
weight_zero_point, bias: -1),
592
torch.fbgemm_linear_quantize_weight: lambda input: -1,
593
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
594
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
595
torch.feature_alpha_dropout: lambda input, p, train: -1,
596
torch.feature_dropout: lambda input, p, train: -1,
597
torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
598
torch.fft.rfft: lambda input, n=None, dim=-1, norm=None: -1,
599
torch.fft.irfft: lambda input, n=None, dim=-1, norm=None: -1,
600
torch.fft.hfft: lambda input, n=None, dim=-1, norm=None: -1,
601
torch.fft.ihfft: lambda input, n=None, dim=-1, norm=None: -1,
602
torch.fft.hfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
603
torch.fft.ihfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
604
torch.fft.hfftn: lambda input, s=None, dim=-1, norm=None: -1,
605
torch.fft.ihfftn: lambda input, s=None, dim=-1, norm=None: -1,
606
torch.fft.fftn: lambda input, s=None, dim=None, norm=None: -1,
607
torch.fft.ifftn: lambda input, s=None, dim=None, norm=None: -1,
608
torch.fft.rfftn: lambda input, s=None, dim=None, norm=None: -1,
609
torch.fft.irfftn: lambda input, s=None, dim=None, norm=None: -1,
610
torch.fft.fft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
611
torch.fft.ifft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
612
torch.fft.rfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
613
torch.fft.irfft2: lambda input, s=None, dim=(-2, -1), norm=None: -1,
614
torch.fft.fftshift: lambda input, dim=None: -1,
615
torch.fft.ifftshift: lambda input, dim=None: -1,
616
torch.fft.fft: lambda input, n=None, dim=-1, norm=None: -1,
617
torch.fix: lambda input, out=None: -1,
618
torch.flatten: lambda input, start_dim=0, end_dim=-1: -1,
619
torch.flip: lambda input, dims: -1,
620
torch.fliplr: lambda input: -1,
621
torch.flipud: lambda input: -1,
622
torch.frobenius_norm: lambda input, dim=None, keepdim=False, out=None: -1,
623
torch.floor: lambda input, out=None: -1,
624
torch.floor_divide: lambda input, other: -1,
625
torch.float_power: lambda input, exponent, out=None: -1,
626
torch.fmod: lambda input, other, out=None: -1,
627
torch.frac: lambda input, out=None: -1,
628
torch.frexp: lambda input, out=None: -1,
629
torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
630
torch._functional_assert_async: lambda input, msg, dep_token: -1,
631
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
632
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
633
torch.gcd: lambda input, other, out=None: -1,
634
torch.ge: lambda input, other, out=None: -1,
635
torch.greater_equal: lambda input, other, out=None: -1,
636
torch.geqrf: lambda input, out=None: -1,
637
torch.i0: lambda input, out=None: -1,
638
torch.inner: lambda input, other, out=None: -1,
639
torch.outer: lambda input, vec2, out=None: -1,
640
torch.ger: lambda input, vec2, out=None: -1,
641
torch.gradient: lambda input, spacing=None, dim=None, edge_order=1: -1,
642
torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
643
torch.grid_sampler_2d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
644
torch.grid_sampler_3d: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,
645
torch.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05, cudnn_enabled=True: -1,
646
torch.gru: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
647
torch.gru_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
648
torch.gt: lambda input, other, out=None: -1,
649
torch.greater: lambda input, other, out=None: -1,
650
torch.hardshrink: lambda input, lambd=0.5: -1,
651
torch.heaviside: lambda input, values, out=None: -1,
652
torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
653
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
654
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
655
torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
656
torch.linalg.householder_product: lambda input, tau: -1,
657
torch.hspmm: lambda mat1, mat2, out=None: -1,
658
torch.hsplit: lambda input, indices_or_sections: -1,
659
torch.hstack: lambda tensors, out=None: -1,
660
torch.hypot: lambda input, other, out=None: -1,
661
torch.igamma: lambda input, other, out=None: -1,
662
torch.igammac: lambda input, other, out=None: -1,
663
torch.imag: lambda input, out=None: -1,
664
torch.index_add: lambda input, dim, index, source: -1,
665
torch.index_copy: lambda input, dim, index, source: -1,
666
torch.index_put: lambda input, indices, values, accumulate=False: -1,
667
torch.index_select: lambda input, dim, index, out=None: -1,
668
torch.index_fill: lambda input, dim, index, value: -1,
669
torch.index_reduce: lambda input, dim, index, source, reduce, include_input=True: -1,
670
torch.isfinite: lambda tensor: -1,
671
torch.isin: lambda e, te, assume_unique=False, invert=False: -1,
672
torch.isinf: lambda tensor: -1,
673
torch.isreal: lambda tensor: -1,
674
torch.isposinf: lambda input, out=None: -1,
675
torch.isneginf: lambda input, out=None: -1,
676
torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
678
torch.int_repr: lambda input: -1,
679
torch.inverse: lambda input, out=None: -1,
680
torch.linalg.inv: lambda input, out=None: -1,
681
torch.linalg.inv_ex: lambda input, check_errors=False, out=None: -1,
682
torch.is_complex: lambda input: -1,
683
torch.is_conj: lambda input: -1,
684
torch.is_neg: lambda input: -1,
685
torch.is_distributed: lambda input: -1,
686
torch.is_inference: lambda input: -1,
687
torch.is_floating_point: lambda input: -1,
688
torch.is_nonzero: lambda input: -1,
689
torch.is_same_size: lambda input, other: -1,
690
torch.is_signed: lambda input: -1,
691
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
692
torch.isnan: lambda input: -1,
693
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
694
normalized=False, onesided=None, length=None, return_complex=False: -1),
695
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
696
torch.kron: lambda input, other: -1,
697
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
698
torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
699
torch.linalg.ldl_factor: lambda input, hermitian=False, out=None: -1,
700
torch.linalg.ldl_solve: lambda LD, pivots, B, hermitian=False, out=None: -1,
701
torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
702
torch.lcm: lambda input, other, out=None: -1,
703
torch.ldexp: lambda input, other, out=None: -1,
704
torch.le: lambda input, other, out=None: -1,
705
torch.less_equal: lambda input, other, out=None: -1,
706
torch.lerp: lambda input, end, weight, out=None: -1,
707
torch.lgamma: lambda input, out=None: -1,
708
torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
709
tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
710
torch.log: lambda input, out=None: -1,
711
torch.log_softmax: lambda input, dim, dtype=None: -1,
712
torch.log10: lambda input, out=None: -1,
713
torch.log1p: lambda input, out=None: -1,
714
torch.log2: lambda input, out=None: -1,
715
torch.logaddexp: lambda input, other, out=None: -1,
716
torch.logaddexp2: lambda input, other, out=None: -1,
717
torch.logdet: lambda input: -1,
718
torch.xlogy: lambda x, y, out=None: -1,
719
torch.logical_and: lambda input, other, out=None: -1,
720
torch.logical_not: lambda input, out=None: -1,
721
torch.logical_or: lambda input, other, out=None: -1,
722
torch.logical_xor: lambda input, other, out=None: -1,
723
torch.logit: lambda input, eps=None: -1,
724
torch.logsumexp: lambda input, names, keepdim=False, out=None: -1,
725
torch.lstm: lambda data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional: -1,
726
torch.lstm_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
727
torch.lt: lambda input, other, out=None: -1,
728
torch.less: lambda input, other, out=None: -1,
729
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
730
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
731
torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
732
torch.masked_fill: lambda input, mask, value: -1,
733
torch.masked_scatter: lambda input, mask, source: -1,
734
torch.masked_select: lambda input, mask, out=None: -1,
735
torch.matmul: lambda input, other, out=None: -1,
736
torch.linalg.lu: lambda input, pivot=True, out=None: -1,
737
torch.linalg.lu_factor: lambda input, pivot=True, out=None: -1,
738
torch.linalg.lu_factor_ex: lambda input, pivot=True, check_errors=False, out=None: -1,
739
torch.linalg.lu_solve: lambda LU, pivots, B, left=True, adjoint=False, out=None: -1,
740
torch.linalg.matmul: lambda input, other, out=None: -1,
741
torch.matrix_power: lambda input, n: -1,
742
torch.linalg.matrix_power: lambda input, n, out=None: -1,
743
torch.linalg.matrix_rank: lambda input, tol=None, hermitian=False: -1,
744
torch.linalg.multi_dot: lambda tensors, out=None: -1,
745
torch.matrix_exp: lambda input: -1,
746
torch.linalg.matrix_exp: lambda input: -1,
747
torch.max: lambda input, out=None: -1,
748
torch.maximum: lambda input, other, out=None: -1,
749
torch.fmax: lambda input, other, out=None: -1,
750
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
751
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
752
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
753
torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
754
return_indices=False, ceil_mode=False: -1),
755
torch.mean: lambda input, dim=None: -1,
756
torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
757
torch.median: lambda input, dim=None: -1,
758
torch.nanmedian: lambda input, dim=None: -1,
759
torch.meshgrid: lambda *tensors, **kwargs: -1,
760
torch.min: lambda input, out=None: -1,
761
torch.minimum: lambda input, other, out=None: -1,
762
torch.fmin: lambda input, other, out=None: -1,
763
torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
764
exponential_average_factor, epsilon: -1),
765
torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
766
torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
767
torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
768
torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
769
groups, benchmark, deterministic: -1),
770
torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
772
torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
773
dropout, train, bidirectional, batch_sizes, dropout_state: -1),
774
torch.mm: lambda input, mat2, out=None: -1,
775
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
776
torch.movedim: lambda input, source, destination: -1,
777
torch.moveaxis: lambda input, source, destination: -1,
778
torch.msort: lambda input, descending=False, out=None: -1,
779
torch.mul: lambda input, other, out=None: -1,
780
torch.multiply: lambda input, other, out=None: -1,
781
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
782
torch.mv: lambda input, vec, out=None: -1,
783
torch.mvlgamma: lambda input, p: -1,
784
torch.narrow: lambda input, dim, start, length: -1,
785
torch.nan_to_num: lambda input, nan=0.0, posinf=None, neginf=None, out=None: -1,
786
torch.native_batch_norm: lambda input, weight, bias, running_mean, running_var, training, momentum, eps: -1,
787
torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
788
torch.native_dropout: lambda input, p, train: -1,
789
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
790
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
791
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
792
torch.native_channel_shuffle: lambda input, groups : -1,
793
torch.ne: lambda input, other, out=None: -1,
794
torch.not_equal: lambda input, other, out=None: -1,
795
torch.neg: lambda input, out=None: -1,
796
torch.negative: lambda input, out=None: -1,
797
torch.nextafter: lambda input, other, out=None: -1,
798
torch.nn.functional.adaptive_avg_pool2d: lambda input, output_size: -1,
799
torch.nn.functional.adaptive_avg_pool3d: lambda input, output_size: -1,
800
torch.nn.functional.adaptive_max_pool1d: lambda input, output_size, return_indices=False: -1,
801
torch.nn.functional.adaptive_max_pool1d_with_indices: lambda input, output_size, return_indices=False: -1,
802
torch.nn.functional.adaptive_max_pool2d: lambda input, output_size, return_indices=False: -1,
803
torch.nn.functional.adaptive_max_pool2d_with_indices: lambda input, output_size, return_indices=False: -1,
804
torch.nn.functional.adaptive_max_pool3d: lambda input, output_size, return_indices=False: -1,
805
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
806
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
807
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
808
torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
809
count_include_pad=True, divisor_override=None: -1),
810
torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
811
count_include_pad=True, divisor_override=None: -1),
812
torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
813
momentum=0.1, eps=1e-05: -1),
814
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
815
torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
816
reduction="mean": -1),
817
torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
818
reduce=None, reduction="mean", pos_weight=None: -1),
819
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
820
torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
821
reduce=None, reduction='mean': -1),
822
torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
823
reduce=None, reduction="mean", label_smoothing=0.0: -1),
824
torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
825
reduction='mean', zero_infinity=False: -1),
826
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
827
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
828
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
829
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
830
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
831
torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
832
scale_grad_by_freq=False, sparse=False: -1),
833
torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
834
scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
835
include_last_offset=False, padding_idx=None: -1),
836
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
837
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
838
torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
839
return_indices=False, _random_samples=None: -1),
840
torch.nn.functional.fractional_max_pool2d_with_indices: (
841
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
842
_random_samples=None: -1),
843
torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
844
return_indices=False, _random_samples=None: -1),
845
torch.nn.functional.fractional_max_pool3d_with_indices: (
846
lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
847
_random_samples=None: -1),
848
torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
849
torch.nn.functional.gelu: lambda input, approximate='none': -1,
850
torch.nn.functional.glu: lambda input, dim=-1: -1,
851
torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
852
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
853
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
854
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
855
torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
856
torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
857
reduction='mean': -1),
858
torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
859
use_input_stats=True, momentum=0.1, eps=1e-05: -1),
860
torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
861
recompute_scale_factor=None, antialias=False: -1),
862
torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
863
torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
864
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
865
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
866
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
867
torch.nn.functional.local_response_norm: lambda input, size, alpha=0.0001, beta=0.75, k=1.0: -1,
868
torch.nn.functional.log_softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
869
torch.nn.functional.logsigmoid: lambda input: -1,
870
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
871
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
872
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
873
torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
874
reduce=None, reduction='mean': -1),
875
torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
876
ceil_mode=False, return_indices=False: -1),
877
torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
878
return_indices=False, ceil_mode=False: -1),
879
torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
880
ceil_mode=False, return_indices=False: -1),
881
torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
882
return_indices=False, ceil_mode=False: -1),
883
torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
884
return_indices=False, ceil_mode=False: -1),
885
torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
886
return_indices=False, ceil_mode=False: -1),
887
torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
888
torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
889
torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
890
torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
891
torch.nn.functional.multi_head_attention_forward: (
892
lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
893
add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
894
need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
895
v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
896
torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
897
reduce=None, reduction='mean': -1),
898
torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
899
reduction='mean': -1),
900
torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
901
reduce=None, reduction='mean': -1),
902
torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
903
reduce=None, reduction='mean': -1),
904
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
905
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
906
torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
907
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
908
torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
909
eps=1e-08, reduce=None, reduction='mean': -1),
910
torch.nn.functional.prelu: lambda input, weight: -1,
911
torch.nn.functional.relu: lambda input, inplace=False: -1,
912
torch.nn.functional.relu6: lambda input, inplace=False: -1,
913
torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
914
torch.nn.functional.selu: lambda input, inplace=False: -1,
915
torch.nn.functional.silu: lambda input, inplace=False: -1,
916
torch.nn.functional.mish: lambda input, inplace=False: -1,
917
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
918
torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
919
torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
920
torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
921
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
922
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
923
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
924
torch.nn.functional.softshrink: lambda input, lambd=0.5: -1,
925
torch.nn.functional.softsign: lambda input: -1,
926
torch.nn.functional.tanhshrink: lambda input: -1,
927
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
928
torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
929
swap=False, size_average=None, reduce=None, reduction='mean': -1),
930
torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
931
distance_function=None, margin=1.0,
932
swap=False, reduction='mean': -1),
933
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
934
torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
935
torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
936
torch.nn.init.constant_: lambda tensor, val: -1,
937
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
938
torch.nonzero: lambda input, as_tuple=False: -1,
939
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
940
torch.argwhere: lambda input: -1,
941
torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
942
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
943
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
944
torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
945
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
946
torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
947
torch.numel: lambda input: -1,
948
torch.orgqr: lambda input, tau: -1,
949
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
950
torch.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
951
torch.permute: lambda self, dim: -1,
952
torch.pca_lowrank: lambda input, q=None, center=True, niter=2: -1,
953
torch.pdist: lambda input, p=2: -1,
954
torch.pinverse: lambda input, rcond=1e-15: -1,
955
torch.linalg.pinv: lambda input, rcond=1e-15, hermitian=False: -1,
956
torch.pixel_shuffle: lambda input, upscale_factor: -1,
957
torch.pixel_unshuffle: lambda input, downscale_factor: -1,
958
torch.poisson: lambda input, generator=None: -1,
959
torch.poisson_nll_loss: lambda input, target, log_input, full, eps, reduction: -1,
960
torch.polygamma: lambda input, n, out=None: -1,
961
torch.positive: lambda input, out=None: -1,
962
torch.prelu: lambda input, weight: -1,
963
torch.ones_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
964
torch.pow: lambda input, exponent, out=None: -1,
965
torch.prod: lambda input, dtype=None: -1,
966
torch.put: lambda input, index, source, accumulate=False: -1,
967
torch.q_per_channel_axis: lambda input: -1,
968
torch.q_per_channel_scales: lambda input: -1,
969
torch.q_per_channel_zero_points: lambda input: -1,
970
torch.q_scale: lambda input: -1,
971
torch.q_zero_point: lambda input: -1,
972
torch.qr: lambda input, some=True, out=None: -1,
973
torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
974
torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
975
torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
976
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
977
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
978
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
979
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
980
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
981
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
983
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
984
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
985
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
986
dilation=(1,), ceil_mode=False: -1),
987
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
988
dilation=(1, 1), ceil_mode=False: -1),
989
torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
990
dilation=(1, 1, 1), ceil_mode=False: -1),
991
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
992
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
993
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
994
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
995
torch.rad2deg: lambda input, out=None: -1,
996
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
997
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
998
torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
999
torch.ravel: lambda input: -1,
1000
torch.real: lambda input, out=None: -1,
1001
torch.vdot: lambda input, other, out=None: -1,
1002
torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
1003
torch.view_as_real: lambda input: -1,
1004
torch.view_as_complex: lambda input: -1,
1005
torch.reciprocal: lambda input, out=None: -1,
1006
torch.relu: lambda input, inplace=False: -1,
1007
torch.remainder: lambda input, other, out=None: -1,
1008
torch.renorm: lambda input, p, dim, maxnorm, out=None: -1,
1009
torch.repeat_interleave: lambda input, dim=None: -1,
1010
torch.reshape: lambda input, shape: -1,
1011
torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
1012
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
1013
torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
1014
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
1015
torch.roll: lambda input, shifts, dims=None: -1,
1016
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
1017
torch.round: lambda input, out=None: -1,
1018
torch.row_stack: lambda tensors, out=None: -1,
1019
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
1020
torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
1021
torch.rsqrt: lambda input, out=None: -1,
1022
torch.rsub: lambda input, other, alpha=1: -1,
1023
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
1024
torch.scatter: lambda input, dim, index, src: -1,
1025
torch.scatter_add: lambda input, dim, index, src: -1,
1026
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
1027
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
1028
torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
1029
torch.select: lambda input, dim, index: -1,
1030
torch.select_scatter: lambda input, src, dim, index: -1,
1031
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
1032
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
1033
torch.selu: lambda input, inplace=False: -1,
1034
torch.sigmoid: lambda input, out=None: -1,
1035
torch.sign: lambda input, out=None: -1,
1036
torch.signbit: lambda input, out=None: -1,
1037
torch.sgn: lambda input, out=None: -1,
1038
torch.sin: lambda input, out=None: -1,
1039
torch.sinc: lambda input, out=None: -1,
1040
torch.sinh: lambda input, out=None: -1,
1041
torch.slogdet: lambda input: -1,
1042
torch.linalg.slogdet: lambda input: -1,
1043
torch.smm: lambda input, mat2: -1,
1044
torch.spmm: lambda input, mat2: -1,
1045
torch.softmax: lambda input, dim, dtype=None: -1,
1046
torch.linalg.solve: lambda A, B, left=True, out=None: -1,
1047
torch.linalg.solve_ex: lambda A, B, left=True, check_errors=False, out=None: -1,
1048
torch.sort: lambda input, dim=-1, descending=False, *, stable=False, out=None: -1,
1049
torch.split: lambda tensor, split_size_or_sections, dim=0: -1,
1050
torch.split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
1051
torch.sqrt: lambda input, out=None: -1,
1052
torch.square: lambda input, out=None: -1,
1053
torch.squeeze: lambda input, dim=None, out=None: -1,
1054
torch.sspaddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
1055
torch.stack: lambda tensors, dim=0, out=None: -1,
1056
torch.std: lambda input, dim=None: -1,
1057
torch.std_mean: lambda input, dim=None: -1,
1058
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
1059
pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
1060
torch.sub: lambda input, other, out=None: -1,
1061
torch.subtract: lambda input, other, out=None: -1,
1062
torch.sum: lambda input, dim=None: -1,
1063
torch.sym_float: lambda input: -1,
1064
torch.sym_int: lambda input: -1,
1065
torch.sym_max: lambda a, b: -1,
1066
torch.sym_min: lambda a, b: -1,
1067
torch.sym_not: lambda input: -1,
1068
torch.sym_ite: lambda a, b, c: -1,
1069
torch._sym_sqrt: lambda input: -1,
1070
torch._sym_cos: lambda input: -1,
1071
torch._sym_cosh: lambda input: -1,
1072
torch._sym_sin: lambda input: -1,
1073
torch._sym_sinh: lambda input: -1,
1074
torch._sym_tan: lambda input: -1,
1075
torch._sym_tanh: lambda input: -1,
1076
torch._sym_asin: lambda input: -1,
1077
torch._sym_acos: lambda input: -1,
1078
torch._sym_atan: lambda input: -1,
1079
torch.nansum: lambda input, dim=None: -1,
1080
torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
1081
torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
1082
torch.linalg.svd: lambda input, full_matrices=True, out=None: -1,
1083
torch.linalg.svdvals: lambda input, out=None: -1,
1084
torch.swapaxes: lambda input, dim0, dim1: -1,
1085
torch.swapdims: lambda input, axis0, axis1: -1,
1086
torch.special.airy_ai: lambda input: -1,
1087
torch.special.bessel_j0: lambda input: -1,
1088
torch.special.bessel_j1: lambda input: -1,
1089
torch.special.bessel_y0: lambda input: -1,
1090
torch.special.bessel_y1: lambda input: -1,
1091
torch.special.chebyshev_polynomial_t: lambda input, n, out=None: -1,
1092
torch.special.chebyshev_polynomial_u: lambda input, n, out=None: -1,
1093
torch.special.chebyshev_polynomial_v: lambda input, n, out=None: -1,
1094
torch.special.chebyshev_polynomial_w: lambda input, n, out=None: -1,
1095
torch.special.digamma: lambda input: -1,
1096
torch.special.entr: lambda input: -1,
1097
torch.special.erf: lambda input: -1,
1098
torch.special.erfc: lambda input: -1,
1099
torch.special.erfcx: lambda input: -1,
1100
torch.special.erfinv: lambda input: -1,
1101
torch.special.exp2: lambda input: -1,
1102
torch.special.expit: lambda input: -1,
1103
torch.special.expm1: lambda input: -1,
1104
torch.special.gammainc: lambda input, other, out=None: -1,
1105
torch.special.gammaincc: lambda input, other, out=None: -1,
1106
torch.special.gammaln: lambda input: -1,
1107
torch.special.hermite_polynomial_h: lambda input, n, out=None: -1,
1108
torch.special.hermite_polynomial_he: lambda input, n, out=None: -1,
1109
torch.special.i0: lambda input: -1,
1110
torch.special.i0e: lambda input: -1,
1111
torch.special.i1: lambda input: -1,
1112
torch.special.i1e: lambda input: -1,
1113
torch.special.laguerre_polynomial_l: lambda input, n, out=None: -1,
1114
torch.special.legendre_polynomial_p: lambda input, n, out=None: -1,
1115
torch.special.log1p: lambda input: -1,
1116
torch.special.log_ndtr: lambda input: -1,
1117
torch.special.log_softmax: lambda input, dim, dtype=None: -1,
1118
torch.special.logit: lambda input: -1,
1119
torch.special.logsumexp: lambda input, dim, keepdim=False, out=None: -1,
1120
torch.special.modified_bessel_i0: lambda input: -1,
1121
torch.special.modified_bessel_i1: lambda input: -1,
1122
torch.special.modified_bessel_k0: lambda input: -1,
1123
torch.special.modified_bessel_k1: lambda input: -1,
1124
torch.special.multigammaln: lambda input, p: -1,
1125
torch.special.ndtr: lambda input: -1,
1126
torch.special.ndtri: lambda input: -1,
1127
torch.special.polygamma: lambda input, n, out=None: -1,
1128
torch.special.psi: lambda input: -1,
1129
torch.special.round: lambda input: -1,
1130
torch.special.scaled_modified_bessel_k0: lambda input: -1,
1131
torch.special.scaled_modified_bessel_k1: lambda input: -1,
1132
torch.special.shifted_chebyshev_polynomial_t: lambda input, n, out=None: -1,
1133
torch.special.shifted_chebyshev_polynomial_u: lambda input, n, out=None: -1,
1134
torch.special.shifted_chebyshev_polynomial_v: lambda input, n, out=None: -1,
1135
torch.special.shifted_chebyshev_polynomial_w: lambda input, n, out=None: -1,
1136
torch.special.sinc: lambda input: -1,
1137
torch.special.softmax: lambda input, dim, dtype=None: -1,
1138
torch.special.spherical_bessel_j0: lambda input: -1,
1139
torch.special.xlog1py: lambda input, other, out=None: -1,
1140
torch.special.xlogy: lambda input, other, out=None: -1,
1141
torch.special.zeta: lambda self, other, out=None: -1,
1142
torch.t: lambda input: -1,
1143
torch.take: lambda input, index: -1,
1144
torch.take_along_dim: lambda input, indices, dim=None, out=None: -1,
1145
torch.tan: lambda input, out=None: -1,
1146
torch.tanh: lambda input, out=None: -1,
1147
torch.linalg.tensorinv: lambda a, ind=2: -1,
1148
torch.linalg.tensorsolve: lambda a, b, dims=None: -1,
1149
torch.tensordot: lambda a, b, dims=2, out=None: -1,
1150
torch.tensor_split: lambda input, indices_or_sections, dim=0: -1,
1151
torch.threshold: lambda input, threshold, value, inplace=False: -1,
1152
torch.tile: lambda input, dims: -1,
1153
torch.topk: lambda input, k, dim=-1, descending=False, out=None: -1,
1154
torch.trace: lambda input: -1,
1155
torch.transpose: lambda input, dim0, dim1: -1,
1156
torch.trapz: lambda y, x=None, dim=-1: -1,
1157
torch.trapezoid: lambda y, x=None, dim=-1: -1,
1158
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
1159
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
1160
torch.tril: lambda input, diagonal=0, out=None: -1,
1161
torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
1163
size_average=None, reduce=None, reduction='mean': -1),
1164
torch.triu: lambda input, diagonal=0, out=None: -1,
1165
torch.true_divide: lambda input, other: -1,
1166
torch.trunc: lambda input, out=None: -1,
1167
torch.unbind: lambda input, dim=0: -1,
1168
torch.unflatten: lambda input, dim, sizes, names: -1,
1169
torch.unique: lambda input, sorted=True, return_inverse=False, return_counts=False, dim=None: -1,
1170
torch.unique_consecutive: lambda input, return_inverse=False, return_counts=False, dim=None: -1,
1171
torch.unravel_index: lambda indices, shape: -1,
1172
torch.unsafe_chunk: lambda input, chunks, dim=0: -1,
1173
torch.unsafe_split: lambda tensor, split_size_or_sections, dim=0: -1,
1174
torch.unsafe_split_with_sizes: lambda tensor, split_size_or_sections, dim=0: -1,
1175
torch.unsqueeze: lambda input, dim, out=None: -1,
1176
torch.linalg.vander: lambda x, N=None: -1,
1177
torch.var: lambda input, dim=None: -1,
1178
torch.var_mean: lambda input, dim=None: -1,
1179
torch.vsplit: lambda input, indices_or_sections: -1,
1180
torch.vstack: lambda tensors, out=None: -1,
1181
torch.where: lambda condition, x=None, y=None: -1,
1182
torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
1183
torch._fw_primal_copy: lambda self, level: -1,
1184
torch._make_dual_copy: lambda primal, tangent, level: -1,
1185
torch.view_as_real_copy: lambda self: -1,
1186
torch.view_as_complex_copy: lambda self: -1,
1187
torch._conj_copy: lambda self: -1,
1188
torch._neg_view_copy: lambda self: -1,
1189
torch.as_strided_copy: lambda self, size, stride, storage_offset=None: -1,
1190
torch._sparse_broadcast_to_copy: lambda self, size: -1,
1191
torch.diagonal_copy: lambda self, offset=0, dim1=0, dim2=1: -1,
1192
torch.expand_copy: lambda self, size, *, implicit=False: -1,
1193
torch.narrow_copy: lambda self, dim, start, length: -1,
1194
torch.permute_copy: lambda self, dims: -1,
1195
torch._reshape_alias_copy: lambda self, size, stride: -1,
1196
torch.select_copy: lambda self, dim, index: -1,
1197
torch.detach_copy: lambda self: -1,
1198
torch.slice_copy: lambda self, dim=0, start=None, end=None, step=1: -1,
1199
torch.split_copy: lambda self, split_size, dim=0: -1,
1200
torch.split_with_sizes_copy: lambda self, split_sizes, dim=0: -1,
1201
torch.squeeze_copy: lambda self, dim: -1,
1202
torch.t_copy: lambda self: -1,
1203
torch.transpose_copy: lambda self, dim0, dim1: -1,
1204
torch.unsqueeze_copy: lambda self, dim: -1,
1205
torch._indices_copy: lambda self: -1,
1206
torch._values_copy: lambda self: -1,
1207
torch.indices_copy: lambda self: -1,
1208
torch.values_copy: lambda self: -1,
1209
torch.crow_indices_copy: lambda self: -1,
1210
torch.col_indices_copy: lambda self: -1,
1211
torch.ccol_indices_copy: lambda self: -1,
1212
torch.row_indices_copy: lambda self: -1,
1213
torch.unbind_copy: lambda self, dim=0: -1,
1214
torch.view_copy: lambda self, dtype: -1,
1215
torch.unfold_copy: lambda self, dimension, size, step: -1,
1216
torch.alias_copy: lambda self: -1,
1217
Tensor.__floordiv__: lambda self, other: -1,
1218
Tensor.__rfloordiv__: lambda self, other: -1,
1219
Tensor.__ifloordiv__: lambda self, other: -1,
1220
Tensor.__truediv__: lambda self, other: -1,
1221
Tensor.__rtruediv__: lambda self, other: -1,
1222
Tensor.__itruediv__: lambda self, other: -1,
1223
Tensor.__lshift__: lambda self, other: -1,
1224
Tensor.__rlshift__: lambda self, other: -1,
1225
Tensor.__ilshift__: lambda self, other: -1,
1226
Tensor.__rshift__: lambda self, other: -1,
1227
Tensor.__rrshift__: lambda self, other: -1,
1228
Tensor.__irshift__: lambda self, other: -1,
1229
Tensor.__and__: lambda self, other: -1,
1230
Tensor.__or__: lambda self, other: -1,
1231
Tensor.__xor__: lambda self, other: -1,
1232
Tensor.__float__: lambda self: -1,
1233
Tensor.__complex__: lambda self: -1,
1234
Tensor.__array__: lambda self, dtype: -1,
1235
Tensor.__bool__: lambda self: -1,
1236
Tensor.__contains__: lambda self, other: -1,
1237
Tensor.__neg__: lambda self: -1,
1238
Tensor.__invert__: lambda self: -1,
1239
Tensor.__mod__: lambda self, other: -1,
1240
Tensor.__rmod__: lambda self, other: -1,
1241
Tensor.__imod__: lambda self, other: -1,
1242
Tensor.__array_wrap__: lambda self, array: -1,
1243
Tensor.__getitem__: lambda self, idx: -1,
1244
Tensor.__deepcopy__: lambda self, memo: -1,
1245
Tensor.__int__: lambda self: -1,
1246
Tensor.__long__: lambda self: -1,
1247
Tensor.__index__: lambda self: -1,
1248
Tensor.__len__: lambda self: -1,
1249
Tensor.__format__: lambda self, format_spec: -1,
1250
Tensor.__reduce_ex__: lambda self, proto: -1,
1251
Tensor.__reversed__: lambda self: -1,
1252
Tensor.__repr__: lambda self, *, tensor_contents=None: -1,
1253
Tensor.__setitem__: lambda self, k, v: -1,
1254
Tensor.__setstate__: lambda self, d: -1,
1255
Tensor.T.__get__: lambda self: -1,
1256
Tensor.H.__get__: lambda self: -1,
1257
Tensor.mT.__get__: lambda self: -1,
1258
Tensor.mH.__get__: lambda self: -1,
1259
Tensor._backward_hooks.__get__: lambda self: -1,
1260
Tensor._post_accumulate_grad_hooks.__get__: lambda self: -1,
1261
Tensor._base.__get__: lambda self: -1,
1262
Tensor._cdata.__get__: lambda self: -1,
1263
Tensor.grad.__get__: lambda self: -1,
1264
Tensor._grad.__get__: lambda self: -1,
1265
Tensor._grad_fn.__get__: lambda self: -1,
1266
Tensor.grad_fn.__get__: lambda self: -1,
1267
Tensor._version.__get__: lambda self: -1,
1268
Tensor._autocast_to_reduced_precision: lambda self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype: -1,
1269
Tensor._autocast_to_full_precision: lambda self, cuda_enabled, cpu_enabled: -1,
1270
Tensor.data.__get__: lambda self: -1,
1271
Tensor.device.__get__: lambda self: -1,
1272
Tensor.dtype.__get__: lambda self: -1,
1273
Tensor.is_cuda.__get__: lambda self: -1,
1274
Tensor.is_cpu.__get__: lambda self: -1,
1275
Tensor.is_xla.__get__: lambda self: -1,
1276
Tensor.is_xpu.__get__: lambda self: -1,
1277
Tensor.is_ipu.__get__: lambda self: -1,
1278
Tensor.is_leaf.__get__: lambda self: -1,
1279
Tensor.retains_grad.__get__: lambda self: -1,
1280
Tensor.is_meta.__get__: lambda self: -1,
1281
Tensor.is_mps.__get__: lambda self: -1,
1282
Tensor.is_mtia.__get__: lambda self: -1,
1283
Tensor.is_nested.__get__: lambda self: -1,
1284
Tensor.is_ort.__get__: lambda self: -1,
1285
Tensor.is_mkldnn.__get__: lambda self: -1,
1286
Tensor.is_quantized.__get__: lambda self: -1,
1287
Tensor.is_sparse.__get__: lambda self: -1,
1288
Tensor.is_sparse_csr.__get__: lambda self: -1,
1289
Tensor.is_vulkan.__get__: lambda self: -1,
1290
Tensor.itemsize.__get__: lambda self: -1,
1291
Tensor.layout.__get__: lambda self: -1,
1292
Tensor.name.__get__: lambda self: -1,
1293
Tensor.names.__get__: lambda self: -1,
1294
Tensor.nbytes.__get__: lambda self: -1,
1295
Tensor.ndim.__get__: lambda self: -1,
1296
Tensor.output_nr.__get__: lambda self: -1,
1297
Tensor.requires_grad.__get__: lambda self: -1,
1298
Tensor.shape.__get__: lambda self: -1,
1299
Tensor.volatile.__get__: lambda self: -1,
1300
Tensor.real.__get__: lambda self: -1,
1301
Tensor.imag.__get__: lambda self: -1,
1302
Tensor.__cuda_array_interface__.__get__: lambda self: -1,
1303
Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1,
1304
Tensor._dimI: lambda self: -1,
1305
Tensor._dimV: lambda self: -1,
1306
Tensor._indices: lambda self: -1,
1307
Tensor._is_view: lambda self: -1,
1308
Tensor._nnz: lambda self: -1,
1309
Tensor.crow_indices: lambda self: -1,
1310
Tensor.col_indices: lambda self: -1,
1311
Tensor.ccol_indices: lambda self: -1,
1312
Tensor.row_indices: lambda self: -1,
1313
Tensor._update_names: lambda self, names, inplace: -1,
1314
Tensor._values: lambda self: -1,
1315
Tensor.adjoint: lambda self: -1,
1316
Tensor.align_as: lambda self, other: -1,
1317
Tensor.align_to: lambda self, order, ellipsis_idx: -1,
1318
Tensor.apply_: lambda self, callable: -1,
1319
Tensor.as_strided: lambda self, size, stride: -1,
1320
Tensor.as_strided_: lambda self, size, stride: -1,
1321
Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False, inputs=None: -1,
1322
Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1,
1323
Tensor.bool: lambda self, memory_format=torch.preserve_format: -1,
1324
Tensor.byte: lambda self, memory_format=torch.preserve_format: -1,
1325
Tensor.char: lambda self, memory_format=torch.preserve_format: -1,
1326
Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1,
1327
Tensor.coalesce: lambda self: -1,
1328
Tensor._coalesced_: lambda self, coalesced: -1,
1329
Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1,
1330
Tensor.copy_: lambda self, src, non_blocking=False: -1,
1331
Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1,
1332
Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1,
1333
Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1,
1334
Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1,
1335
Tensor.data_ptr: lambda self: -1,
1336
Tensor.dense_dim: lambda self: -1,
1337
Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1,
1338
Tensor.dim: lambda self: -1,
1339
Tensor.dim_order: lambda self: -1,
1340
Tensor.double: lambda self, memory_format=torch.preserve_format: -1,
1341
Tensor.cdouble: lambda self, memory_format=torch.preserve_format: -1,
1342
Tensor.element_size: lambda self: -1,
1343
Tensor.expand: lambda self, size: -1,
1344
Tensor.expand_as: lambda self, other: -1,
1345
Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1,
1346
Tensor.fill_: lambda self, value: -1,
1347
Tensor.fill_diagonal_: lambda self, value: -1,
1348
Tensor.float: lambda self, memory_format=torch.preserve_format: -1,
1349
Tensor.cfloat: lambda self, memory_format=torch.preserve_format: -1,
1350
Tensor.geometric_: lambda self, p, *, generator=None: -1,
1351
Tensor.get_device: lambda self: -1,
1352
Tensor.half: lambda self, memory_format=torch.preserve_format: -1,
1353
Tensor.chalf: lambda self, memory_format=torch.preserve_format: -1,
1354
Tensor.has_names: lambda self: -1,
1355
Tensor.indices: lambda self: -1,
1356
Tensor.int: lambda self, memory_format=torch.preserve_format: -1,
1357
Tensor.is_coalesced: lambda self: -1,
1358
Tensor.is_contiguous: lambda self: -1,
1359
Tensor.is_inference: lambda self: -1,
1360
Tensor.is_pinned: lambda self: -1,
1361
Tensor.is_set_to: lambda self, tensor: -1,
1362
Tensor.is_shared: lambda self: -1,
1363
Tensor.item: lambda self: -1,
1364
Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1,
1365
Tensor.log_softmax: lambda self, dim: -1,
1366
Tensor.long: lambda self, memory_format=torch.preserve_format: -1,
1367
Tensor.map_: lambda self, tensor, callable: -1,
1368
Tensor.map2_: lambda self, x, y, callable: -1,
1369
Tensor.mm: lambda self, mat2: -1,
1370
Tensor.module_load: lambda self, other: -1,
1371
Tensor.narrow_copy: lambda self, dimension, start, length: -1,
1372
Tensor.ndimension: lambda self: -1,
1373
Tensor.nelement: lambda self: -1,
1374
Tensor._nested_tensor_size: lambda self: -1,
1375
Tensor._nested_tensor_storage_offsets: lambda self: -1,
1376
Tensor._nested_tensor_strides: lambda self: -1,
1377
Tensor.normal_: lambda self: -1,
1378
Tensor.numpy: lambda self: -1,
1379
Tensor.permute: lambda self, dim: -1,
1380
Tensor.pin_memory: lambda self: -1,
1381
Tensor.put_: lambda self, indices, tensor, accumulate=False: -1,
1382
Tensor.qscheme: lambda self: -1,
1383
Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1,
1384
Tensor.record_stream: lambda self, stream: -1,
1385
Tensor.refine_names: lambda self, names: -1,
1386
Tensor.register_hook: lambda self, hook: -1,
1387
Tensor.register_post_accumulate_grad_hook: lambda self, hook: -1,
1388
Tensor.rename: lambda self, name: -1,
1389
Tensor.repeat: lambda self, *size: -1,
1390
Tensor.requires_grad_: lambda self, requires_grad=True: -1,
1391
Tensor.reshape_as: lambda self, other: -1,
1392
Tensor.resize: lambda self, *size: -1,
1393
Tensor.resize_: lambda self, size: -1,
1394
Tensor.resize_as: lambda self, other: -1,
1395
Tensor.resize_as_sparse_: lambda self, other: -1,
1396
Tensor.retain_grad: lambda self: -1,
1397
Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1,
1398
Tensor.select_scatter: lambda self, src, dim, index: -1,
1399
Tensor.share_memory_: lambda self: -1,
1400
Tensor.short: lambda self, memory_format=torch.preserve_format: -1,
1401
Tensor.size: lambda self: -1,
1402
Tensor.slice_scatter: lambda self, src, dim=0, start=None, end=None, step=1: -1,
1403
Tensor.sparse_dim: lambda self: -1,
1404
Tensor.sparse_mask: lambda self, mask: -1,
1405
Tensor._sparse_mask_projection: lambda self, mask, accumulate_matches=False: -1,
1406
Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1,
1407
Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1,
1408
Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1,
1409
Tensor.storage: lambda self: -1,
1410
Tensor.untyped_storage: lambda self: -1,
1411
Tensor.storage_offset: lambda self: -1,
1412
Tensor.storage_type: lambda self: -1,
1413
Tensor.sum_to_size: lambda self, size: -1,
1414
Tensor.tile: lambda self, *reps: -1,
1415
Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
1416
Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
1417
Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
1418
Tensor.to_sparse: lambda self: -1,
1419
Tensor.tolist: lambda self: -1,
1420
Tensor.to_mkldnn: lambda self: -1,
1421
Tensor.type_as: lambda self, other: -1,
1422
Tensor.unfold: lambda self, dimension, size, step: -1,
1423
Tensor.uniform_: lambda self, from_=0, to=1: -1,
1424
Tensor.values: lambda self: -1,
1425
Tensor.view: lambda self, shape: -1,
1426
Tensor.view_as: lambda self, other: -1,
1427
Tensor.zero_: lambda self: -1,
1428
Tensor.__dlpack__: lambda self, stream=None: -1,
1429
Tensor.__dlpack_device__: lambda self: -1,
1430
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
1434
ignored = get_ignored_functions()
1436
for k, v in ret.items():
1441
"__" + k.__name__ + "__",
1442
"__i" + k.__name__ + "__",
1443
"__r" + k.__name__ + "__",
1446
if k.__name__.startswith("bitwise_"):
1449
subname = k.__name__[len("bitwise_"):]
1451
"__" + subname + "__",
1452
"__i" + subname + "__",
1453
"__r" + subname + "__"
1457
func = getattr(Tensor, name, None)
1458
if callable(func) and func not in ret and func not in ignored:
1464
def wrap_torch_function(dispatcher: Callable):
1465
"""Wraps a given function with ``__torch_function__`` -related functionality.
1469
dispatcher: Callable
1470
A callable that returns an iterable of Tensor-likes passed into the function.
1474
This decorator may reduce the performance of your code. Generally, it's enough to express
1475
your code as a series of functions that, themselves, support __torch_function__. If you
1476
find yourself in the rare situation where this is not the case, e.g. if you're wrapping a
1477
low-level library and you also need it to work for Tensor-likes, then this function is available.
1481
>>> def dispatcher(a): # Must have the same signature as func
1483
>>> @torch.overrides.wrap_torch_function(dispatcher)
1484
>>> def func(a): # This will make func dispatchable by __torch_function__
1488
@functools.wraps(func)
1489
def wrapped(*args, **kwargs):
1490
relevant_args = dispatcher(*args, **kwargs)
1491
if has_torch_function(relevant_args):
1492
return handle_torch_function(wrapped, relevant_args, *args, **kwargs)
1494
return func(*args, **kwargs)
1500
def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
1501
"""Returns a list of arguments on which to call __torch_function__.
1503
Checks arguments in relevant_args for __torch_function__ implementations,
1504
storing references to the arguments and their types in overloaded_args and
1505
overloaded_types in order of calling precedence. Only distinct types are
1506
considered. If a type is a subclass of another type it will have higher
1507
precedence, otherwise the precedence order is the same as the order of
1508
arguments in relevant_args, that is, from left-to-right in the argument list.
1510
The precedence-determining algorithm implemented in this function is
1511
described in `NEP-0018`_.
1513
See torch::append_overloaded_arg for the equivalent function in the C++
1518
relevant_args : iterable of array-like
1519
Iterable of array-like arguments to check for __torch_function__
1522
get_type_fn : callable, optional
1523
Function to call on each argument in relevant_args to get its type.
1527
overloaded_args : list
1528
Arguments from relevant_args on which to call __torch_function__
1529
methods, in the order in which they should be called.
1532
https://numpy.org/neps/nep-0018-array-function-protocol.html
1534
if get_type_fn is None:
1538
if not torch._C._is_torch_function_enabled():
1541
overloaded_types: Set[Type] = set()
1542
overloaded_args: List[Any] = []
1543
for arg in relevant_args:
1544
arg_type = get_type_fn(arg)
1551
if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
1552
arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
1555
if overloaded_types:
1556
overloaded_types.add(arg_type)
1560
index = len(overloaded_args)
1561
for i, old_arg in enumerate(overloaded_args):
1562
if issubclass(arg_type, get_type_fn(old_arg)):
1565
overloaded_args.insert(index, arg)
1567
overloaded_types = {arg_type}
1568
overloaded_args = [arg]
1569
return overloaded_args
1572
def handle_torch_function(
1573
public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
1574
"""Implement a function with checks for ``__torch_function__`` overrides.
1576
See torch::autograd::handle_torch_function for the equivalent of this
1577
function in the C++ implementation.
1581
public_api : function
1582
Function exposed by the public torch API originally called like
1583
``public_api(*args, **kwargs)`` on which arguments are now being
1585
relevant_args : iterable
1586
Iterable of arguments to check for __torch_function__ methods.
1588
Arbitrary positional arguments originally passed into ``public_api``.
1590
Arbitrary keyword arguments originally passed into ``public_api``.
1595
Result from calling ``implementation`` or an ``__torch_function__``
1596
method, as appropriate.
1600
TypeError : if no implementation is found.
1605
... if has_torch_function_unary(a):
1606
... return handle_torch_function(func, (a,), a)
1610
overloaded_args = _get_overloaded_args(relevant_args)
1612
types = tuple(map(type, overloaded_args))
1615
if _is_torch_function_mode_enabled():
1618
with _pop_mode_temporarily() as mode:
1619
result = mode.__torch_function__(public_api, types, args, kwargs)
1620
if result is not NotImplemented:
1624
for overloaded_arg in overloaded_args:
1627
torch_func_method = overloaded_arg.__torch_function__
1628
if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
1629
torch_func_method is not torch._C._disabled_torch_function_impl:
1630
warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
1631
"will be an error in future, please define it as a classmethod.",
1636
result = torch_func_method(public_api, types, args, kwargs)
1638
if result is not NotImplemented:
1641
func_name = f'{public_api.__module__}.{public_api.__name__}'
1643
f"no implementation found for '{func_name}' on types that implement "
1644
f'__torch_function__: {[type(arg) for arg in overloaded_args]}'
1646
if _is_torch_function_mode_enabled():
1647
msg += f" nor in mode {_get_current_function_mode()}"
1648
raise TypeError(msg)
1650
has_torch_function = _add_docstr(
1651
_has_torch_function,
1652
r"""Check for __torch_function__ implementations in the elements of an iterable
1653
or if a __torch_function__ mode is enabled. Considers exact ``Tensor`` s
1654
and ``Parameter`` s non-dispatchable. Use this to guard a call to
1655
:func:`handle_torch_function`; don't use it to test if something
1656
is Tensor-like, use :func:`is_tensor_like` instead.
1659
relevant_args : iterable
1660
Iterable or arguments to check for __torch_function__ methods.
1664
True if any of the elements of relevant_args have __torch_function__
1665
implementations, False otherwise.
1668
torch.is_tensor_like
1669
Checks if something is a Tensor-like, including an exact ``Tensor``.
1673
has_torch_function_unary = _add_docstr(
1674
_has_torch_function_unary,
1675
r"""Special case of `has_torch_function` for single inputs.
1677
`has_torch_function((t,))`
1679
`has_torch_function_unary(t)`
1680
which skips unnecessary packing and unpacking work.
1684
has_torch_function_variadic = _add_docstr(
1685
_has_torch_function_variadic,
1686
r"""Special case of `has_torch_function` that skips tuple creation.
1688
This uses the METH_FASTCALL protocol introduced in Python 3.7
1691
`has_torch_function((a, b))`
1693
`has_torch_function_variadic(a, b)`
1694
which skips unnecessary packing and unpacking work.
1698
@functools.lru_cache(None)
1699
def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
1700
overridable_funcs = collections.defaultdict(list)
1702
tested_namespaces = [
1703
("torch", torch, torch.__all__),
1704
("torch.functional", torch.functional, torch.functional.__all__),
1705
("torch.nn.functional", torch.nn.functional, dir(torch.nn.functional)),
1706
("torch.nn.init", torch.nn.init, dir(torch.nn.init)),
1707
("torch.Tensor", torch.Tensor, dir(torch.Tensor)),
1708
("torch.linalg", torch.linalg, dir(torch.linalg)),
1709
("torch.fft", torch.fft, dir(torch.fft)),
1710
("torch.special", torch.special, dir(torch.special)),
1712
for namespace_str, namespace, ns_funcs in tested_namespaces:
1713
for func_name in ns_funcs:
1716
if namespace is not torch.Tensor:
1717
if func_name.startswith('__'):
1719
elif func_name.startswith('_'):
1721
elif func_name.endswith('_'):
1723
elif not func_name[0].islower():
1725
elif func_name == 'unique_dim':
1728
func = getattr(namespace, func_name)
1729
if getattr(object, func_name, None) == func:
1731
if func_name == '__weakref__':
1733
func = getattr(namespace, func_name)
1734
if namespace is torch.Tensor and getattr(object, func_name, None) == func:
1737
if isinstance(func, types.ModuleType):
1740
if isinstance(func, __future__._Feature):
1743
if not callable(func) and hasattr(func, "__get__"):
1744
index[func.__get__] = f"{namespace_str}.{func_name}.__get__"
1745
index[func.__set__] = f"{namespace_str}.{func_name}.__set__"
1748
if func.__get__ in get_ignored_functions():
1749
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
1750
"but still has an explicit override")
1751
assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
1754
overridable_funcs[func].append(func.__get__)
1757
if not callable(func):
1760
index[func] = f"{namespace_str}.{func_name}"
1766
if func in get_ignored_functions():
1767
msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
1768
"but still has an explicit override")
1769
assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
1771
overridable_funcs[namespace].append(func)
1772
return overridable_funcs, index
1774
@_disable_user_warnings
1775
def get_overridable_functions() -> Dict[Any, List[Callable]]:
1776
"""List functions that are overridable via __torch_function__
1780
Dict[Any, List[Callable]]
1781
A dictionary that maps namespaces that contain overridable functions
1782
to functions in that namespace that can be overridden.
1784
return _get_overridable_functions()[0]
1786
@_disable_user_warnings
1788
"""Get a human readable string name for a function passed to
1794
Function to resolve the name of.
1799
Name of the function; if eval'ed it should give back the input
1802
if isinstance(f, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)):
1804
return _get_overridable_functions()[1].get(f)
1806
@functools.lru_cache(None)
1807
def _get_tensor_methods() -> Set[Callable]:
1808
""" Returns a set of the overridable methods on ``torch.Tensor`` """
1809
overridable_funcs = get_overridable_functions()
1810
methods = set(overridable_funcs[torch.Tensor])
1813
@_disable_user_warnings
1814
def is_tensor_method_or_property(func: Callable) -> bool:
1816
Returns True if the function passed in is a handler for a
1817
method or property belonging to ``torch.Tensor``, as passed
1818
into ``__torch_function__``.
1821
For properties, their ``__get__`` method must be passed in.
1823
This may be needed, in particular, for the following reasons:
1825
1. Methods/properties sometimes don't contain a `__module__` slot.
1826
2. They require that the first passed-in argument is an instance
1827
of ``torch.Tensor``.
1831
>>> is_tensor_method_or_property(torch.Tensor.add)
1833
>>> is_tensor_method_or_property(torch.add)
1836
return func in _get_tensor_methods() or func.__name__ == "__get__"
1838
def is_tensor_like(inp):
1840
Returns ``True`` if the passed-in input is a Tensor-like.
1842
Currently, this occurs whenever there's a ``__torch_function__``
1843
attribute on the type of the input.
1847
A subclass of tensor is generally a Tensor-like.
1849
>>> class SubTensor(torch.Tensor): ...
1850
>>> is_tensor_like(SubTensor([0]))
1853
Built-in or user types aren't usually Tensor-like.
1855
>>> is_tensor_like(6)
1857
>>> is_tensor_like(None)
1859
>>> class NotATensor: ...
1860
>>> is_tensor_like(NotATensor())
1863
But, they can be made Tensor-like by implementing __torch_function__.
1865
>>> class TensorLike:
1867
... def __torch_function__(cls, func, types, args, kwargs):
1869
>>> is_tensor_like(TensorLike())
1872
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
1874
class TorchFunctionMode:
1876
A ``TorchFunctionMode`` allows you to override the meaning of all
1877
``__torch_function__`` overrideable functions within a dynamic scope,
1878
without having to actually create a tensor subclass or manually
1879
monkey-patch functions in the PyTorch API. Some common situations
1880
where you should use a mode:
1882
* You want to override the meaning of factory functions, or other
1883
functions that do not otherwise take a tensor as an argument
1884
(these cannot be overridden with tensor subclasses).
1886
* You want to override the behavior of all functions without needing
1887
to wrap your inputs in tensor subclasses; e.g., if you are just
1888
interested in logging intermediate computations.
1890
* You want to control the order of execution of various tensor
1891
subclasses explicitly, rather than implicitly via the return of
1894
Independent subclasses of :class:`TorchFunctionMode` are compositional:
1895
modes can be pushed onto a stack using ``with MyMode():``.
1896
When you call functions in the PyTorch API inside your
1897
``__torch_function__`` implementation, by default, they will forward on to
1898
the next mode on the mode stack. If you want recursively call back into
1899
your current ``__torch_function__`` implementation, either explicitly
1900
invoke ``self.__torch_function__(...)``, or use the context manager
1901
``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
1902
API self-referential (beware of infinite loops, in this case!)
1904
inner: "TorchFunctionMode"
1910
def __torch_function__(self, func, types, args=(), kwargs=None):
1911
raise NotImplementedError()
1913
def __enter__(self):
1917
def __exit__(self, exc_type, exc_val, exc_tb):
1921
def push(cls, *args, **kwargs):
1922
warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
1923
instance = cls(*args, **kwargs)
1927
def _get_current_function_mode():
1928
stack_len = _len_torch_function_stack()
1929
return _get_function_stack_at(stack_len - 1) if stack_len > 0 else None
1932
def _get_current_function_mode_stack():
1933
stack_len = _len_torch_function_stack()
1934
return [_get_function_stack_at(i) for i in range(stack_len)]
1936
def _push_mode(mode):
1937
_push_on_torch_function_stack(mode)
1941
old = _pop_torch_function_stack()
1945
@contextlib.contextmanager
1946
def _pop_mode_temporarily():
1953
class BaseTorchFunctionMode(TorchFunctionMode):
1954
def __torch_function__(self, func, types, args=(), kwargs=None):
1957
return func(*args, **kwargs)
1960
@contextlib.contextmanager
1961
def enable_reentrant_dispatch():
1969
with torch._C._RestorePythonTLSSnapshot():