1
# mypy: allow-untyped-decorators
2
# mypy: allow-untyped-defs
6
from contextlib import contextmanager
22
from torch import _C, _ops, Tensor
23
from torch.utils._exposed_in import exposed_in
25
from . import autograd, utils
28
device_types_t = Optional[Union[str, Sequence[str]]]
29
log = logging.getLogger(__name__)
32
@exposed_in("torch.library")
35
fn: Optional[Callable] = None,
38
mutates_args: Union[str, Iterable[str]],
39
device_types: device_types_t = None,
40
schema: Optional[str] = None,
42
"""Wraps a function into a custom operator.
44
Reasons why you may want to create a custom op include:
45
- Wrapping a third-party library or custom kernel to work with PyTorch
46
subsystems like Autograd.
47
- Preventing torch.compile/export/FX tracing from peeking inside your function.
49
This API is used as a decorator around a function (please see examples).
50
The provided function must have type hints; these are needed to interface
51
with PyTorch's various subsystems.
54
name (str): A name for the custom op that looks like "{namespace}::{name}",
55
e.g. "mylib::my_linear". The name is used as the op's stable identifier
56
in PyTorch subsystems (e.g. torch.export, FX graphs).
57
To avoid name collisions, please use your project name as the namespace;
58
e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
59
mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
60
This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
61
it pessimistically assumes that all inputs to the operator are being mutated.
62
device_types (None | str | Sequence[str]): The device type(s) the function
63
is valid for. If no device type is provided, then the function
64
is used as the default implementation for all device types.
65
Examples: "cpu", "cuda".
66
When registering a device-specific implementation for an operator that accepts no Tensors,
67
we require the operator to have a "device: torch.device" argument.
68
schema (None | str): A schema string for the operator. If None
69
(recommended) we'll infer a schema for the operator from its type
70
annotations. We recommend letting us infer a schema unless you
71
have a specific reason not to.
72
Example: "(Tensor x, int y) -> (Tensor, Tensor)".
75
We recommend not passing in a ``schema`` arg and instead letting us infer
76
it from the type annotations. It is error-prone to write your own schema.
77
You may wish to provide your own schema if our interpretation of
78
the type annotation is not what you want.
79
For more info on how to write a schema string, see
80
`here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_
84
>>> from torch import Tensor
85
>>> from torch.library import custom_op
86
>>> import numpy as np
88
>>> @custom_op("mylib::numpy_sin", mutates_args=())
89
>>> def numpy_sin(x: Tensor) -> Tensor:
90
>>> x_np = x.cpu().numpy()
91
>>> y_np = np.sin(x_np)
92
>>> return torch.from_numpy(y_np).to(device=x.device)
94
>>> x = torch.randn(3)
96
>>> assert torch.allclose(y, x.sin())
98
>>> # Example of a custom op that only works for one device type.
99
>>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
100
>>> def numpy_sin_cpu(x: Tensor) -> Tensor:
102
>>> y_np = np.sin(x_np)
103
>>> return torch.from_numpy(y_np)
105
>>> x = torch.randn(3)
106
>>> y = numpy_sin_cpu(x)
107
>>> assert torch.allclose(y, x.sin())
109
>>> # Example of a custom op that mutates an input
110
>>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
111
>>> def numpy_sin_inplace(x: Tensor) -> None:
113
>>> np.sin(x_np, out=x_np)
115
>>> x = torch.randn(3)
116
>>> expected = x.sin()
117
>>> numpy_sin_inplace(x)
118
>>> assert torch.allclose(x, expected)
120
>>> # Example of a factory function
121
>>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu")
122
>>> def bar(device: torch.device) -> Tensor:
123
>>> return torch.ones(3)
133
schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
137
namespace, opname = name.split("::")
138
result = CustomOpDef(namespace, opname, schema_str, fn)
139
if schema is not None:
140
# Check that schema's alias annotations match those of `mutates_args`.
142
for arg in result._opoverload._schema.arguments:
143
if arg.alias_info is not None and arg.alias_info.is_write:
144
expected.add(arg.name)
145
if expected != set(mutates_args):
147
# Adjacent f-strings are concatenated verbatim, so every fragment must carry
# its own trailing space and close any backtick it opens.
f"Attempted to create a custom op with `mutates_args={mutates_args}` "
148
f"and `schema={schema}`. The schema suggests that the op mutates {expected} "
149
f"which is different from what was provided to us in `mutates_args`. "
150
f"Please make these consistent."
152
result.register_kernel(device_types)(fn)
161
"""CustomOpDef is a wrapper around a function that turns it into a custom op.
163
It has various methods for registering additional behavior for this
166
You should not instantiate CustomOpDef directly; instead, use the
167
:func:`torch.library.custom_op` API.
170
def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
171
# Fields used to interface with the PyTorch dispatcher
172
self._namespace = namespace
174
self._schema = schema
178
self._backend_fns: Dict[Union[str, None], Callable] = {}
179
self._abstract_fn: Optional[Callable] = None
180
self._setup_context_fn: Optional[Callable] = None
181
self._backward_fn: Optional[Callable] = None
182
self._torch_dispatch_fns: Dict[type, Callable] = {}
183
self._vmap_fn: Optional[Callable] = None
185
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
186
self._register_to_dispatcher()
187
self._disabled_kernel: Set = set()
188
OPDEFS[self._qualname] = self
191
# Read as an attribute everywhere (``self._qualname``), hence a property.
@property
def _qualname(self) -> str:
    """Fully qualified operator name in ``"{namespace}::{name}"`` form."""
    return f"{self._namespace}::{self._name}"
194
def __repr__(self) -> str:
    """Return a debug representation of the form ``<CustomOpDef(ns::name)>``."""
    return f"<CustomOpDef({self._qualname})>"
198
def set_kernel_enabled(self, device_type: str, enabled: bool = True):
200
Disable or re-enable an already registered kernel for this custom operator.
202
If the kernel is already disabled/enabled, this is a no-op.
205
If a kernel is first disabled and then registered, it is disabled until enabled again.
208
device_type (str): The device type to disable/enable the kernel for.
209
enabled (bool): Whether to enable (``True``) or disable (``False``) the kernel.
212
>>> inp = torch.randn(1)
214
>>> # define custom op `f`.
215
>>> @custom_op("mylib::f", mutates_args=())
216
>>> def f(x: Tensor) -> Tensor:
217
>>> return torch.zeros(1)
219
>>> print(f(inp)) # tensor([0.]), default kernel
221
>>> @f.register_kernel("cpu")
223
>>> return torch.ones(1)
225
>>> print(f(inp)) # tensor([1.]), CPU kernel
227
>>> # temporarily disable the CPU kernel
228
>>> with f.set_kernel_enabled("cpu", enabled = False):
229
>>> print(f(inp)) # tensor([0.]) with CPU kernel disabled
232
action = "enable" if enabled else "disable"
233
originally_disabled = device_type in self._disabled_kernel
234
if device_type not in self._backend_fns:
236
"Attempted to %s kernel for %s but no kernel was registered for this device type.",
242
if originally_disabled:
244
"Attempted to disable kernel for %s but it was already disabled.",
248
self._disabled_kernel.add(device_type)
249
else: # enable the kernel
250
if not originally_disabled:
252
"Attempted to enable kernel for %s but it was already enabled.",
256
self._disabled_kernel.remove(device_type)
261
# restore original state
262
if originally_disabled:
263
self._disabled_kernel.add(device_type)
265
self._disabled_kernel.discard(device_type)
268
self, device_types: device_types_t, fn: Optional[Callable] = None, /
270
"""Register an implementation for a device type for this operator.
272
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
273
This API may be used as a decorator.
276
fn (Callable): The function to register as the implementation for
277
the given device types.
278
device_types (None | str | Sequence[str]): The device type(s) to register an impl to.
281
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
283
>>> from torch import Tensor
284
>>> from torch.library import custom_op
285
>>> import numpy as np
287
>>> # Create a custom op that works on cpu
288
>>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
289
>>> def numpy_sin(x: Tensor) -> Tensor:
291
>>> y_np = np.sin(x_np)
292
>>> return torch.from_numpy(y_np)
294
>>> # Add implementations for the cuda device
295
>>> @numpy_sin.register_kernel("cuda")
297
>>> x_np = x.cpu().numpy()
298
>>> y_np = np.sin(x_np)
299
>>> return torch.from_numpy(y_np).to(device=x.device)
301
>>> x_cpu = torch.randn(3)
302
>>> x_cuda = x_cpu.cuda()
303
>>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
304
>>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
309
if device_types is None or isinstance(device_types, str):
310
dtypes: List[Union[str, None]] = [device_types]
312
dtypes = list(device_types)
313
for device_type in dtypes:
314
if device_type not in self._backend_fns:
316
def backend_impl(*args, **kwargs):
317
# Checks the assumption that outputs cannot alias
318
# inputs or other outputs.
320
id(tensor.untyped_storage())
321
for tensor in iter_tensors(args, kwargs)
324
result = self._backend_fns[device_type](*args, **kwargs)
326
tuple_result = result
327
if not isinstance(result, tuple):
328
tuple_result = (result,)
329
for tensor in iter_tensors(tuple_result, {}):
330
key = id(tensor.untyped_storage())
331
if id(tensor.untyped_storage()) in storages:
332
fn = self._backend_fns[device_type]
333
module = inspect.getmodule(fn)
335
f"{self._name} (with implementation in {module}): "
336
f"The output of this custom operator (1) must not "
337
f"also be an input to this custom operator and "
338
f"(2) may not alias any inputs to this custom operator "
339
f"or other returns. "
340
f"The most common way to trigger this error is if "
341
f"we have y = custom_op(x) and y and x are the same Tensor. "
342
f"Please instead return a clone of the offending output "
343
f"tensor(s) (e.g. return x.clone()) or refactor the custom "
344
f"operator to not return y."
349
if device_type is None:
351
self._name, backend_impl, "CompositeExplicitAutograd"
357
_C._dispatch_key_for_device(device_type),
360
# Wrap function to choose between the default implementation or the device-specific
361
# implementation depending on if the kernel is disabled.
362
@torch._disable_dynamo
363
def wrapped_fn(*args, **kwargs):
364
if device_type in self._disabled_kernel:
365
return self._init_fn(*args, **kwargs)
367
return fn(*args, **kwargs)
369
self._backend_fns[device_type] = wrapped_fn
372
if device_types is not None and not utils.has_tensor_arg(
373
self._opoverload._schema
375
device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
376
if device_arg_index is None:
378
"Functions without tensor inputs are required to have a `device: torch.device` argument"
380
self._register_backend_select_dispatcher(device_arg_index)
382
# See NOTE: [Supporting decorator and non-decorator usage]
387
def register_fake(self, fn: Callable, /) -> Callable:
388
r"""Register a FakeTensor implementation for this custom op.
390
This is necessary to get the operator to work efficiently with torch.compile.
392
The Fake impl (sometimes also known as a meta kernel or abstract impl)
393
specifies the behavior of this operator on Tensors that carry no data.
394
Given some input Tensors with certain properties
395
(sizes/strides/storage_offset/device), it specifies what the properties of
396
the output Tensors are.
398
Please see :func:`torch.library.impl_abstract` for more details.
401
fn (Callable): The function to register as the FakeTensor
406
>>> import numpy as np
407
>>> from torch import Tensor
409
>>> # Example 1: an operator without data-dependent output shape
410
>>> @torch.library.custom_op("mylib::linear", mutates_args=())
411
>>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
412
>>> return (x @ weight.t()) + bias
414
>>> @linear.register_fake
415
>>> def _(x, weight, bias):
416
>>> assert x.dim() == 2
417
>>> assert weight.dim() == 2
418
>>> assert bias.dim() == 1
419
>>> assert x.shape[1] == weight.shape[1]
420
>>> assert weight.shape[0] == bias.shape[0]
421
>>> assert x.device == weight.device
422
>>> return x.new_empty(x.size(0), weight.size(0))
424
>>> x = torch.randn(2, 2)
425
>>> weight = torch.randn(2, 2)
426
>>> bias = torch.randn(2)
427
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
428
>>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
429
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
430
>>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
432
>>> # Example 2: an operator with data-dependent output shape
433
>>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
434
>>> def nonzero(x: Tensor) -> Tensor:
435
>>> x_np = x.cpu().numpy()
436
>>> res = np.stack(np.nonzero(x_np), axis=1)
437
>>> return torch.tensor(res, device=x.device)
439
>>> @nonzero.register_fake
441
>>> # Number of nonzero-elements is data-dependent.
442
>>> # Since we cannot peek at the data in an abstract impl,
443
>>> # we use the ctx object to construct a new symint that
444
>>> # represents the data-dependent size.
445
>>> ctx = torch.library.get_ctx()
446
>>> nnz = ctx.new_dynamic_size()
447
>>> shape = [nnz, x.dim()]
448
>>> result = x.new_empty(shape, dtype=torch.int64)
451
>>> x = torch.tensor([0, 1, 2, 0, 0, 1])
452
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
453
>>> out = torch.compile(nonzero, fullgraph=True)(x)
454
>>> # xdoctest: +SKIP("Requires Python <= 3.11")
455
>>> assert torch.allclose(out, x.nonzero())
458
self._abstract_fn = fn
461
def register_torch_dispatch(
462
self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
464
r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
466
This allows for open registration to specify the behavior between the operator
467
and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
468
or the operator directly.
470
Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
474
if torch_dispatch_class not in self._torch_dispatch_fns:
476
def inner(*args, **kwargs):
477
return self._torch_dispatch_fns[torch_dispatch_class](
481
self._lib._register_torch_dispatch_rule(
482
self._name, torch_dispatch_class, inner
484
self._torch_dispatch_fns[torch_dispatch_class] = fn
492
def register_autograd(
497
setup_context: Optional[Callable] = None,
499
r"""Register a backward formula for this custom op.
501
In order for an operator to work with autograd, you need to register
503
1. You must tell us how to compute gradients during the backward pass
504
by providing us a "backward" function.
505
2. If you need any values from the forward to compute gradients, you can
506
use `setup_context` to save values for backward.
508
``backward_fn`` runs during the backward pass. It accepts ``(ctx, *grads)``:
509
- ``grads`` is one or more gradients. The number of gradients matches
510
the number of outputs of the operator.
511
The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
512
:class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the
513
same as :meth:`torch.autograd.Function.backward`.
515
``setup_context(ctx, inputs, output)`` runs during the forward pass.
516
Please save quantities needed for backward onto the ``ctx`` object via
517
either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
518
or assigning them as attributes of ``ctx``. If your custom op has
519
kwarg-only arguments, we expect the signature of ``setup_context``
520
to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
522
Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is,
523
they may not directly access :meth:`torch.Tensor.data_ptr` and they must
524
not depend on or mutate global state. If you need a non-traceable backward,
525
you can make it a separate custom_op that you call inside ``backward_fn``.
529
>>> import numpy as np
530
>>> from torch import Tensor
532
>>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
533
>>> def numpy_sin(x: Tensor) -> Tensor:
534
>>> x_np = x.cpu().numpy()
535
>>> y_np = np.sin(x_np)
536
>>> return torch.from_numpy(y_np).to(device=x.device)
538
>>> def setup_context(ctx, inputs, output) -> Tensor:
540
>>> ctx.save_for_backward(x)
542
>>> def backward(ctx, grad):
543
>>> x, = ctx.saved_tensors
544
>>> return grad * x.cos()
546
>>> numpy_sin.register_autograd(backward, setup_context=setup_context)
548
>>> x = torch.randn(3, requires_grad=True)
550
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
551
>>> assert torch.allclose(grad_x, x.cos())
553
>>> # Example with a keyword-only arg
554
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
555
>>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
556
>>> x_np = x.cpu().numpy()
557
>>> y_np = x_np * val
558
>>> return torch.from_numpy(y_np).to(device=x.device)
560
>>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor:
561
>>> ctx.val = keyword_only_inputs["val"]
563
>>> def backward(ctx, grad):
564
>>> return grad * ctx.val
566
>>> numpy_mul.register_autograd(backward, setup_context=setup_context)
568
>>> x = torch.randn(3, requires_grad=True)
569
>>> y = numpy_mul(x, val=3.14)
570
>>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
571
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
574
schema = self._opoverload._schema
575
if not utils.is_functional_schema(schema):
577
f"Cannot register autograd formula for non-functional operator "
578
f"{self} with schema {schema}. Please create "
579
f"a functional operator and register an autograd formula for that."
582
self._backward_fn = backward
583
self._setup_context_fn = setup_context
585
def _register_to_dispatcher(self) -> None:
587
schema_str = self._name + self._schema
588
cpp_schema = _C.parse_schema(schema_str)
589
if utils.has_kwarg_only_tensors(cpp_schema):
590
# If you want to support this, the progression is:
591
# - supporting kwarg-only Tensors that are non-differentiable
592
# - supporting kwarg-only Tensors (regardless of differentiability)
593
raise NotImplementedError(
594
f"custom_op with kwarg-only Tensor args. Please make your "
595
f"tensors not kwarg-only. Got: {schema_str}"
600
tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
602
self._opoverload = utils.lookup_op(self._qualname)
604
def fake_impl(*args, **kwargs):
605
if self._abstract_fn is None:
606
if utils.can_generate_trivial_fake_impl(self._opoverload):
609
f"There was no fake impl registered for {self}. "
610
f"This is necessary for torch.compile/export/fx tracing to work. "
611
f"Please use `{self._init_fn.__name__}.register_fake` to add an "
614
return self._abstract_fn(*args, **kwargs)
616
lib._register_fake(self._name, fake_impl, _stacklevel=4)
618
autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
619
lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)
621
schema = self._opoverload._schema
622
if schema.is_mutable:
624
def adinplaceorview_impl(keyset, *args, **kwargs):
625
for arg, val in utils.zip_schema(schema, args, kwargs):
626
if not arg.alias_info:
628
if not arg.alias_info.is_write:
630
if isinstance(val, Tensor):
631
torch.autograd.graph.increment_version(val)
632
elif isinstance(val, (tuple, list)):
634
if isinstance(v, Tensor):
635
torch.autograd.graph.increment_version(v)
636
with _C._AutoDispatchBelowADInplaceOrView():
637
return self._opoverload.redispatch(
638
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
643
adinplaceorview_impl,
648
def _register_backend_select_dispatcher(self, device_arg_index: int):
650
Switch on the device argument to select the correct backend to dispatch to.
653
def backend_select(keyset, *args, **kwargs):
654
# The device argument is read positionally from ``args``; its ``.type``
# is the key used to look up a kernel in ``self._backend_fns``.
# NOTE(review): a device passed by keyword would not be found in ``args``
# -- confirm callers always pass it positionally.
device = args[device_arg_index].type
655
# No kernel registered for this device type: the message below is reported.
if device not in self._backend_fns:
657
f"{self._name} does not have a kernel registered for {device}. "
658
"Please use register_kernel to do so."
660
# ``_dispatch_key_for_device`` yields an attribute name, which is then
# resolved to the corresponding member of ``_C.DispatchKey``.
dispatch_key = _C._dispatch_key_for_device(device)
661
dispatch_key = getattr(_C.DispatchKey, dispatch_key)
662
# Redispatch with a key set built from only the selected dispatch key.
return self._opoverload.redispatch(
663
_C.DispatchKeySet(dispatch_key), *args, **kwargs
666
# Install ``backend_select`` as this op's "BackendSelect" implementation.
self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
668
def __call__(self, *args, **kwargs):
    """Invoke the operator by forwarding all arguments to ``self._opoverload``."""
    return self._opoverload(*args, **kwargs)
673
func: Optional[Callable] = None,
675
r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
677
This API may be used as a decorator.
679
In order for an operator to work with :func:`torch.vmap`, you may need to register a
680
vmap implementation in the following signature:
682
``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
684
where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
686
It specifies how to compute the batched version of ``op`` given inputs with an additional
687
dimension (specified by ``in_dims``).
689
For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
690
if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
691
specifying what dimension of the Tensor is being vmapped over.
693
``info`` is a collection of additional metadata that may be helpful:
694
``info.batch_size`` specifies the size of the dimension being vmapped over, while
695
``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
697
The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
698
``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
699
per output that specifies if the output has the vmapped dimension and what index it is in.
703
>>> import numpy as np
704
>>> from torch import Tensor
705
>>> from typing import Tuple
707
>>> def to_numpy(tensor):
708
>>> return tensor.cpu().numpy()
710
>>> lib = torch.library.Library("mylib", "FRAGMENT")
711
>>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
712
>>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
713
>>> x_np = to_numpy(x)
714
>>> dx = torch.tensor(3 * x_np ** 2, device=x.device)
715
>>> return torch.tensor(x_np ** 3, device=x.device), dx
717
>>> def numpy_cube_vmap(info, in_dims, x):
718
>>> result = numpy_cube(x)
719
>>> return result, (in_dims[0], in_dims[0])
721
>>> numpy_cube.register_vmap(numpy_cube_vmap)
723
>>> x = torch.randn(3)
724
>>> torch.vmap(numpy_cube)(x)
726
>>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
727
>>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
728
>>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
730
>>> @numpy_mul.register_vmap
731
>>> def numpy_mul_vmap(info, in_dims, x, y):
732
>>> x_bdim, y_bdim = in_dims
733
>>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
734
>>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
736
>>> result = result.movedim(-1, 0)
740
>>> x = torch.randn(3)
741
>>> y = torch.randn(3)
742
>>> torch.vmap(numpy_mul)(x, y)
744
from torch._functorch.autograd_function import custom_function_call_vmap_helper
745
from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
748
need_register = self._vmap_fn is None
753
def wrapped_func(keyset, *args, **kwargs):
754
interpreter = retrieve_current_functorch_interpreter()
755
return custom_function_call_vmap_helper(
756
interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
760
self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
766
return register(func)
769
# NOTE: [Supporting decorator and non-decorator usage]
771
# Some APIs may be both used as a decorator and not as a decorator.
777
# >>> # Usage 1: not as a decorator
778
# >>> numpy_sin.register_kernel("cuda", fn)
780
# >>> # Usage 2: as a decorator
781
# >>> @numpy_sin.register_kernel("cuda")
785
# The way we support this is that `register_kernel` accepts an optional `fn`.
786
# If `fn` is provided (Usage 1), then we know that the user is using it not
788
# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
792
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
793
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
796
def get_library_allowing_overwrite(
    namespace: str, name: str
) -> "torch.library.Library":
    """Create a ``FRAGMENT`` library for ``namespace`` and record it for this op.

    If ``OPDEF_TO_LIB`` already holds a library under ``"{namespace}::{name}"``,
    that library's ``_destroy()`` is called and its entry is removed before the
    new library is created.

    Args:
        namespace: Operator namespace, e.g. ``"mylib"``.
        name: Operator name within the namespace.

    Returns:
        The newly created library, which is also stored in ``OPDEF_TO_LIB``.
    """
    qualname = f"{namespace}::{name}"

    if qualname in OPDEF_TO_LIB:
        OPDEF_TO_LIB[qualname]._destroy()
        del OPDEF_TO_LIB[qualname]

    lib = torch.library.Library(namespace, "FRAGMENT")  # noqa: TOR901
    OPDEF_TO_LIB[qualname] = lib
    # The caller stores the result (``self._lib = ...``), so it must be returned.
    return lib
def iter_tensors(
    args: Tuple[Any, ...], kwargs: Dict[str, Any], allowed_nesting: int = 1
) -> Iterator[Tensor]:
    """Yield the Tensors found in ``args`` and then in ``kwargs`` values.

    Tuples and lists are searched recursively while ``allowed_nesting`` is
    positive; each level of nesting consumes one unit. Non-Tensor leaves and
    containers nested deeper than allowed are silently skipped.

    Args:
        args: Positional arguments to scan (any length).
        kwargs: Keyword arguments to scan; only the values are inspected.
        allowed_nesting: How many levels of tuple/list nesting to descend into.

    Yields:
        Each Tensor encountered, in argument order (positional first).
    """

    def check(arg):
        if isinstance(arg, Tensor):
            yield arg
        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
            # Nested containers are scanned as positional args with one less
            # level of nesting permitted.
            yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)

    for arg in args:
        yield from check(arg)
    for kwarg in kwargs.values():
        yield from check(kwarg)
826
op: Union[CustomOpDef, _ops.OpOverload, str]
827
) -> Optional[CustomOpDef]:
828
if isinstance(op, CustomOpDef):
830
if isinstance(op, _ops.OpOverload):
832
assert isinstance(op, str)