pytorch

Fork
0
/
custom_ops.py 
835 lines · 33.4 KB
1
# mypy: allow-untyped-decorators
2
# mypy: allow-untyped-defs
3
import inspect
4
import logging
5
import weakref
6
from contextlib import contextmanager
7
from typing import (
8
    Any,
9
    Callable,
10
    Dict,
11
    Iterable,
12
    Iterator,
13
    List,
14
    Optional,
15
    Sequence,
16
    Set,
17
    Tuple,
18
    Union,
19
)
20

21
import torch
22
from torch import _C, _ops, Tensor
23
from torch.utils._exposed_in import exposed_in
24

25
from . import autograd, utils
26

27

28
# Accepted forms for a ``device_types`` argument: None, a single device-type
# string such as "cpu", or a sequence of such strings.
device_types_t = Union[str, Sequence[str], None]
log = logging.getLogger(name=__name__)
30

31

32
@exposed_in("torch.library")
def custom_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: device_types_t = None,
    schema: Optional[str] = None,
) -> Callable:
    """Wraps a function into a custom operator.

    Reasons why you may want to create a custom op include:
    - Wrapping a third-party library or custom kernel to work with PyTorch
    subsystems like Autograd.
    - Preventing torch.compile/export/FX tracing from peeking inside your function.

    This API is used as a decorator around a function (please see examples).
    The provided function must have type hints; these are needed to interface
    with PyTorch's various subsystems.

    Args:
        name (str): A name for the custom op that looks like "{namespace}::{name}",
            e.g. "mylib::my_linear". The name is used as the op's stable identifier
            in PyTorch subsystems (e.g. torch.export, FX graphs).
            To avoid name collisions, please use your project name as the namespace;
            e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace.
        fn (Callable, optional): The function to wrap. If omitted, a decorator
            is returned instead.
        mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates.
            This MUST be accurate, otherwise, the behavior is undefined. If "unknown",
            it pessimistically assumes that all inputs to the operator are being mutated.
        device_types (None | str | Sequence[str]): The device type(s) the function
            is valid for. If no device type is provided, then the function
            is used as the default implementation for all device types.
            Examples: "cpu", "cuda".
            When registering a device-specific implementation for an operator that accepts no Tensors,
            we require the operator to have a ``device: torch.device`` argument.
        schema (None | str): A schema string for the operator. If None
            (recommended) we'll infer a schema for the operator from its type
            annotations. We recommend letting us infer a schema unless you
            have a specific reason not to.
            Example: "(Tensor x, int y) -> (Tensor, Tensor)".

    Returns:
        A :class:`CustomOpDef` wrapping ``fn`` when ``fn`` is given; otherwise a
        decorator that takes the function and returns that :class:`CustomOpDef`.

    Raises:
        ValueError: If ``name`` does not contain exactly one ``"::"``, or if an
            explicit ``schema`` marks a different set of arguments as mutated
            than ``mutates_args`` does.

    .. note::
        We recommend not passing in a ``schema`` arg and instead letting us infer
        it from the type annotations. It is error-prone to write your own schema.
        You may wish to provide your own schema if our interpretation of
        the type annotation is not what you want.
        For more info on how to write a schema string, see
        `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func>`_

    Examples::
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> @custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that only works for one device type.
        >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu")
        >>> def numpy_sin_cpu(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> x = torch.randn(3)
        >>> y = numpy_sin_cpu(x)
        >>> assert torch.allclose(y, x.sin())
        >>>
        >>> # Example of a custom op that mutates an input
        >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu")
        >>> def numpy_sin_inplace(x: Tensor) -> None:
        >>>     x_np = x.numpy()
        >>>     np.sin(x_np, out=x_np)
        >>>
        >>> x = torch.randn(3)
        >>> expected = x.sin()
        >>> numpy_sin_inplace(x)
        >>> assert torch.allclose(x, expected)
        >>>
        >>> # Example of a factory function
        >>> @torch.library.custom_op("mylib::bar", mutates_args=(), device_types="cpu")
        >>> def bar(device: torch.device) -> Tensor:
        >>>     return torch.ones(3)
        >>>
        >>> bar("cpu")

    """

    def inner(fn):
        if schema is None:
            schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args)
        else:
            schema_str = schema

        # Validate the name up front so a malformed name gives a clear error
        # instead of an opaque tuple-unpacking failure.
        name_parts = name.split("::")
        if len(name_parts) != 2:
            raise ValueError(
                f"Expected the custom op name to look like '{{namespace}}::{{name}}' "
                f"(e.g. 'mylib::my_linear'), but got '{name}'."
            )
        namespace, opname = name_parts

        result = CustomOpDef(namespace, opname, schema_str, fn)
        if schema is not None:
            # Check that schema's alias annotations match those of `mutates_args`.
            expected = {
                arg.name
                for arg in result._opoverload._schema.arguments
                if arg.alias_info is not None and arg.alias_info.is_write
            }
            if expected != set(mutates_args):
                raise ValueError(
                    f"Attempted to create a custom op with `mutates_args={mutates_args}` "
                    f"and `schema={schema}`. The schema suggests that the op mutates {expected} "
                    f"which is different from what was provided to us in `mutates_args`. "
                    f"Please make these consistent."
                )
        result.register_kernel(device_types)(fn)
        return result

    # Support both `@custom_op(name, ...)` and `custom_op(name, fn, ...)`.
    if fn is None:
        return inner
    return inner(fn)
158

159

160
class CustomOpDef:
161
    """CustomOpDef is a wrapper around a function that turns it into a custom op.
162

163
    It has various methods for registering additional behavior for this
164
    custom op.
165

166
    You should not instantiate CustomOpDef directly; instead, use the
167
    :func:`torch.library.custom_op` API.
168
    """
169

170
    def __init__(self, namespace: str, name: str, schema: str, fn: Callable) -> None:
        # Identity of the operator as seen by the PyTorch dispatcher.
        self._namespace, self._name, self._schema = namespace, name, schema

        # The user function the op was created from. It also serves as the
        # fallback kernel while a device-specific kernel is disabled.
        self._init_fn = fn

        # Registrations filled in later by the register_* methods.
        self._backend_fns: Dict[Union[str, None], Callable] = dict()
        self._abstract_fn: Optional[Callable] = None
        self._setup_context_fn: Optional[Callable] = None
        self._backward_fn: Optional[Callable] = None
        self._torch_dispatch_fns: Dict[type, Callable] = dict()
        self._vmap_fn: Optional[Callable] = None

        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
        self._register_to_dispatcher()
        # Device types whose kernels are currently disabled (see set_kernel_enabled).
        self._disabled_kernel: Set = set()
        # NOTE(review): OPDEFS is defined elsewhere in this file; presumably a
        # registry of all CustomOpDefs keyed by qualified name.
        OPDEFS[self._qualname] = self
189

190
    @property
191
    def _qualname(self) -> str:
192
        return f"{self._namespace}::{self._name}"
193

194
    def __repr__(self) -> str:
195
        return f"<CustomOpDef({self._qualname})>"
196

197
    @contextmanager
198
    def set_kernel_enabled(self, device_type: str, enabled: bool = True):
199
        """
200
        Disable or re-enable an already registered kernel for this custom operator.
201

202
        If the kernel is already disabled/enabled, this is a no-op.
203

204
        Note:
205
            If a kernel is first disabled and then registered, it is disabled until enabled again.
206

207
        Args:
208
            device_type (str): The device type to disable/enable the kernel for.
209
            disable (bool): Whether to disable or enable the kernel.
210

211
        Example:
212
            >>> inp = torch.randn(1)
213
            >>>
214
            >>> # define custom op `f`.
215
            >>> @custom_op("mylib::f", mutates_args=())
216
            >>> def f(x: Tensor) -> Tensor:
217
            >>>     return torch.zeros(1)
218
            >>>
219
            >>> print(f(inp))  # tensor([0.]), default kernel
220
            >>>
221
            >>> @f.register_kernel("cpu")
222
            >>> def _(x):
223
            >>>     return torch.ones(1)
224
            >>>
225
            >>> print(f(inp))  # tensor([1.]), CPU kernel
226
            >>>
227
            >>> # temporarily disable the CPU kernel
228
            >>> with f.set_kernel_enabled("cpu", enabled = False):
229
            >>>     print(f(inp))  # tensor([0.]) with CPU kernel disabled
230

231
        """
232
        action = "enable" if enabled else "disable"
233
        originally_disabled = device_type in self._disabled_kernel
234
        if device_type not in self._backend_fns:
235
            log.warning(
236
                "Attempted to %s kernel for %s but no kernel was registered for this device type.",
237
                action,
238
                device_type,
239
            )
240

241
        if not enabled:
242
            if originally_disabled:
243
                log.warning(
244
                    "Attempted to disable kernel for %s but it was already disabled.",
245
                    device_type,
246
                )
247
            else:
248
                self._disabled_kernel.add(device_type)
249
        else:  # enable the kernel
250
            if not originally_disabled:
251
                log.warning(
252
                    "Attempted to enable kernel for  %s but it was already enabled.",
253
                    device_type,
254
                )
255
            else:
256
                self._disabled_kernel.remove(device_type)
257

258
        try:
259
            yield
260
        finally:
261
            # restore original state
262
            if originally_disabled:
263
                self._disabled_kernel.add(device_type)
264
            else:
265
                self._disabled_kernel.discard(device_type)
266

267
    def register_kernel(
268
        self, device_types: device_types_t, fn: Optional[Callable] = None, /
269
    ) -> Callable:
270
        """Register an implementation for a device type for this operator.
271

272
        Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
273
        This API may be used as a decorator.
274

275
        Args:
276
            fn (Callable): The function to register as the implementation for
277
                the given device types.
278
            device_types (str | Sequence[str]): The device device_types to register an impl to.
279

280
        Examples::
281
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
282
            >>> import torch
283
            >>> from torch import Tensor
284
            >>> from torch.library import custom_op
285
            >>> import numpy as np
286
            >>>
287
            >>> # Create a custom op that works on cpu
288
            >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
289
            >>> def numpy_sin(x: Tensor) -> Tensor:
290
            >>>     x_np = x.numpy()
291
            >>>     y_np = np.sin(x_np)
292
            >>>     return torch.from_numpy(y_np)
293
            >>>
294
            >>> # Add implementations for the cuda device
295
            >>> @numpy_sin.register_kernel("cuda")
296
            >>> def _(x):
297
            >>>     x_np = x.cpu().numpy()
298
            >>>     y_np = np.sin(x_np)
299
            >>>     return torch.from_numpy(y_np).to(device=x.device)
300
            >>>
301
            >>> x_cpu = torch.randn(3)
302
            >>> x_cuda = x_cpu.cuda()
303
            >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
304
            >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())
305

306
        """
307

308
        def inner(fn):
309
            if device_types is None or isinstance(device_types, str):
310
                dtypes: List[Union[str, None]] = [device_types]
311
            else:
312
                dtypes = list(device_types)
313
            for device_type in dtypes:
314
                if device_type not in self._backend_fns:
315

316
                    def backend_impl(*args, **kwargs):
317
                        # Checks the assumption that outputs cannot alias
318
                        # inputs or other outputs.
319
                        storages = {
320
                            id(tensor.untyped_storage())
321
                            for tensor in iter_tensors(args, kwargs)
322
                        }
323

324
                        result = self._backend_fns[device_type](*args, **kwargs)
325

326
                        tuple_result = result
327
                        if not isinstance(result, tuple):
328
                            tuple_result = (result,)
329
                        for tensor in iter_tensors(tuple_result, {}):
330
                            key = id(tensor.untyped_storage())
331
                            if id(tensor.untyped_storage()) in storages:
332
                                fn = self._backend_fns[device_type]
333
                                module = inspect.getmodule(fn)
334
                                raise RuntimeError(
335
                                    f"{self._name} (with implementation in {module}): "
336
                                    f"The output of this custom operator (1) must not "
337
                                    f"also be an input to this custom operator and "
338
                                    f"(2) may not alias any inputs to this custom operator "
339
                                    f"or other returns. "
340
                                    f"The most common way to trigger this error is if "
341
                                    f"we have y = custom_op(x) and y and x are the same Tensor. "
342
                                    f"Please instead return a clone of the offending output "
343
                                    f"tensor(s) (e.g. return x.clone()) or refactor the custom "
344
                                    f"operator to not return y."
345
                                )
346
                            storages.add(key)
347
                        return result
348

349
                    if device_type is None:
350
                        self._lib.impl(
351
                            self._name, backend_impl, "CompositeExplicitAutograd"
352
                        )
353
                    else:
354
                        self._lib.impl(
355
                            self._name,
356
                            backend_impl,
357
                            _C._dispatch_key_for_device(device_type),
358
                        )
359

360
                # Wrap function to choose between the default implementation or the device-specific
361
                # implementation depending on if the kernel is disabled.
362
                @torch._disable_dynamo
363
                def wrapped_fn(*args, **kwargs):
364
                    if device_type in self._disabled_kernel:
365
                        return self._init_fn(*args, **kwargs)
366
                    else:
367
                        return fn(*args, **kwargs)
368

369
                self._backend_fns[device_type] = wrapped_fn
370
            return fn
371

372
        if device_types is not None and not utils.has_tensor_arg(
373
            self._opoverload._schema
374
        ):
375
            device_arg_index = utils.get_device_arg_index(self._opoverload._schema)
376
            if device_arg_index is None:
377
                raise ValueError(
378
                    "Functions without tensor inputs are required to have a `device: torch.device` argument"
379
                )
380
            self._register_backend_select_dispatcher(device_arg_index)
381

382
        # See NOTE: [Supporting decorator and non-decorator usage]
383
        if fn is None:
384
            return inner
385
        return inner(fn)
386

387
    def register_fake(self, fn: Callable, /) -> Callable:
388
        r"""Register a FakeTensor implementation for this custom op.
389

390
        This is necessary to get the operator to work efficiently with torch.compile.
391

392
        The Fake impl (sometimes also known as a meta kernel or abstract impl)
393
        specifies the behavior of this operator on Tensors that carry no data.
394
        Given some input Tensors with certain properties
395
        (sizes/strides/storage_offset/device), it specifies what the properties of
396
        the output Tensors are.
397

398
        Please see :func:`torch.library.impl_abstract` for more details.
399

400
        Args:
401
            fn (Callable): The function to register as the FakeTensor
402
                implementation.
403

404
        Examples:
405
            >>> import torch
406
            >>> import numpy as np
407
            >>> from torch import Tensor
408
            >>>
409
            >>> # Example 1: an operator without data-dependent output shape
410
            >>> @torch.library.custom_op("mylib::linear", mutates_args=())
411
            >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
412
            >>>     return (x @ weight.t()) + bias
413
            >>>
414
            >>> @linear.register_fake
415
            >>> def _(x, weight, bias):
416
            >>>     assert x.dim() == 2
417
            >>>     assert weight.dim() == 2
418
            >>>     assert bias.dim() == 1
419
            >>>     assert x.shape[1] == weight.shape[1]
420
            >>>     assert weight.shape[0] == bias.shape[0]
421
            >>>     assert x.device == weight.device
422
            >>>     return x.new_empty(x.size(0), weight.size(0))
423
            >>>
424
            >>> x = torch.randn(2, 2)
425
            >>> weight = torch.randn(2, 2)
426
            >>> bias = torch.randn(2)
427
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
428
            >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias)
429
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
430
            >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias))
431
            >>>
432
            >>> # Example 2: an operator with data-dependent output shape
433
            >>> @torch.library.custom_op("mylib::nonzero", mutates_args=())
434
            >>> def nonzero(x: Tensor) -> Tensor:
435
            >>>     x_np = x.cpu().numpy()
436
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
437
            >>>     return torch.tensor(res, device=x.device)
438
            >>>
439
            >>> @nonzero.register_fake
440
            >>> def _(x):
441
            >>>     # Number of nonzero-elements is data-dependent.
442
            >>>     # Since we cannot peek at the data in an abstract impl,
443
            >>>     # we use the ctx object to construct a new symint that
444
            >>>     # represents the data-dependent size.
445
            >>>     ctx = torch.library.get_ctx()
446
            >>>     nnz = ctx.new_dynamic_size()
447
            >>>     shape = [nnz, x.dim()]
448
            >>>     result = x.new_empty(shape, dtype=torch.int64)
449
            >>>     return result
450
            >>>
451
            >>> x = torch.tensor([0, 1, 2, 0, 0, 1])
452
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
453
            >>> out = torch.compile(nonzero, fullgraph=True)(x)
454
            >>> # xdoctest: +SKIP("Requires Python <= 3.11")
455
            >>> assert torch.allclose(out, x.nonzero())
456

457
        """
458
        self._abstract_fn = fn
459
        return fn
460

461
    def register_torch_dispatch(
462
        self, torch_dispatch_class: Any, fn: Optional[Callable] = None, /
463
    ) -> Callable:
464
        r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
465

466
        This allows for open registration to specify the behavior between the operator
467
        and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
468
        or the operator directly.
469

470
        Please see :func:`torch.library.register_torch_dispatch` for examples and more details.
471
        """
472

473
        def register(fn):
474
            if torch_dispatch_class not in self._torch_dispatch_fns:
475

476
                def inner(*args, **kwargs):
477
                    return self._torch_dispatch_fns[torch_dispatch_class](
478
                        *args, **kwargs
479
                    )
480

481
                self._lib._register_torch_dispatch_rule(
482
                    self._name, torch_dispatch_class, inner
483
                )
484
            self._torch_dispatch_fns[torch_dispatch_class] = fn
485
            return fn
486

487
        if fn is None:
488
            return register
489
        else:
490
            return register(fn)
491

492
    def register_autograd(
        self,
        backward: Callable,
        /,
        *,
        setup_context: Optional[Callable] = None,
    ) -> None:
        r"""Register a backward formula for this custom op.

        In order for an operator to work with autograd, you need to register
        a backward formula:
        1. You must tell us how to compute gradients during the backward pass
        by providing us a "backward" function.
        2. If you need any values from the forward to compute gradients, you can
        use `setup_context` to save values for backward.

        ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
        - ``grads`` is one or more gradients. The number of gradients matches
        the number of outputs of the operator.
        The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
        :class:`torch.autograd.Function`. The semantics of ``backward`` are the
        same as :meth:`torch.autograd.Function.backward`.

        ``setup_context(ctx, inputs, output)`` runs during the forward pass.
        Please save quantities needed for backward onto the ``ctx`` object via
        either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
        or assigning them as attributes of ``ctx``. If your custom op has
        kwarg-only arguments, we expect the signature of ``setup_context``
        to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

        Both ``setup_context`` and ``backward`` must be traceable. That is,
        they may not directly access :meth:`torch.Tensor.data_ptr` and they must
        not depend on or mutate global state. If you need a non-traceable backward,
        you can make it a separate custom_op that you call inside ``backward``.

        Calling this again replaces any previously registered ``backward`` and
        ``setup_context``.

        Args:
            backward (Callable): The backward formula, called as
                ``backward(ctx, *grads)``.
            setup_context (Callable, optional): Called during the forward pass
                to save values needed by ``backward``.

        Raises:
            RuntimeError: If the operator's schema is not functional (for
                example, it mutates its inputs).

        Examples:
            >>> import torch
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
            >>> def numpy_sin(x: Tensor) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = np.sin(x_np)
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> def setup_context(ctx, inputs, output) -> None:
            >>>     x, = inputs
            >>>     ctx.save_for_backward(x)
            >>>
            >>> def backward(ctx, grad):
            >>>     x, = ctx.saved_tensors
            >>>     return grad * x.cos()
            >>>
            >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
            >>>
            >>> x = torch.randn(3, requires_grad=True)
            >>> y = numpy_sin(x)
            >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
            >>> assert torch.allclose(grad_x, x.cos())
            >>>
            >>> # Example with a keyword-only arg
            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
            >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
            >>>     x_np = x.cpu().numpy()
            >>>     y_np = x_np * val
            >>>     return torch.from_numpy(y_np).to(device=x.device)
            >>>
            >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
            >>>     ctx.val = keyword_only_inputs["val"]
            >>>
            >>> def backward(ctx, grad):
            >>>     return grad * ctx.val
            >>>
            >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
            >>>
            >>> x = torch.randn(3, requires_grad=True)
            >>> y = numpy_mul(x, val=3.14)
            >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y))
            >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

        """
        schema = self._opoverload._schema
        if not utils.is_functional_schema(schema):
            raise RuntimeError(
                f"Cannot register autograd formula for non-functional operator "
                f"{self} with schema {schema}. Please create "
                f"a functional operator and register an autograd formula for that."
            )

        # Only stored here; the Autograd kernel installed by
        # _register_to_dispatcher is given `self` and is expected to read these.
        self._backward_fn = backward
        self._setup_context_fn = setup_context
584

585
    def _register_to_dispatcher(self) -> None:
        """Define this operator in the dispatcher and install its built-in kernels.

        In order, this:
        1. defines the op from ``self._name + self._schema``, tagged with
           ``pt2_compliant_tag`` and ``needs_fixed_stride_order``;
        2. sets ``self._opoverload`` to the resulting operator overload;
        3. registers a fake impl that forwards to whatever :meth:`register_fake`
           stored;
        4. registers an Autograd kernel built by ``autograd.make_autograd_impl``;
        5. for mutable schemas only, registers an ADInplaceOrView kernel.

        Raises:
            NotImplementedError: If the schema has kwarg-only Tensor arguments.
        """
        lib = self._lib
        schema_str = self._name + self._schema
        cpp_schema = _C.parse_schema(schema_str)
        if utils.has_kwarg_only_tensors(cpp_schema):
            # If you want to support this, the progression is:
            # - supporting kwarg-only Tensors that are non-differentiable
            # - supporting kwarg-only Tensors (regardless of differentiability)
            raise NotImplementedError(
                f"custom_op with kwarg-only Tensor args. Please make your "
                f"tensors not kwarg-only. Got: {schema_str}"
            )

        lib.define(
            schema_str,
            tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order],
        )
        self._opoverload = utils.lookup_op(self._qualname)

        def fake_impl(*args, **kwargs):
            # `_abstract_fn` is read at call time, so a later register_fake()
            # takes effect without re-registering with the dispatcher.
            if self._abstract_fn is None:
                # NOTE(review): presumably true for ops whose fake result
                # needs no user code (e.g. ops returning nothing) — confirm in utils.
                if utils.can_generate_trivial_fake_impl(self._opoverload):
                    return None
                raise RuntimeError(
                    f"There was no fake impl registered for {self}. "
                    f"This is necessary for torch.compile/export/fx tracing to work. "
                    f"Please use `{self._init_fn.__name__}.register_fake` to add a "
                    f"fake impl."
                )
            return self._abstract_fn(*args, **kwargs)

        # NOTE(review): _stacklevel presumably counts frames back to the user's
        # call site; keep this call directly in this method, or adjust it.
        lib._register_fake(self._name, fake_impl, _stacklevel=4)

        autograd_impl = autograd.make_autograd_impl(self._opoverload, self)
        lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True)

        schema = self._opoverload._schema
        if schema.is_mutable:

            def adinplaceorview_impl(keyset, *args, **kwargs):
                # Bump the version counter of every Tensor the schema marks as
                # written to (including Tensors inside list/tuple args), so
                # autograd's error checking sees them as modified in place.
                for arg, val in utils.zip_schema(schema, args, kwargs):
                    if not arg.alias_info or not arg.alias_info.is_write:
                        continue
                    if isinstance(val, Tensor):
                        torch.autograd.graph.increment_version(val)
                    elif isinstance(val, (tuple, list)):
                        for v in val:
                            if isinstance(v, Tensor):
                                torch.autograd.graph.increment_version(v)
                # Then continue dispatching to the keys after ADInplaceOrView.
                with _C._AutoDispatchBelowADInplaceOrView():
                    return self._opoverload.redispatch(
                        keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
                    )

            lib.impl(
                self._name,
                adinplaceorview_impl,
                "ADInplaceOrView",
                with_keyset=True,
            )
647

648
    def _register_backend_select_dispatcher(self, device_arg_index: int):
649
        """
650
        Switch on the device argument to select the correct backend to dispatch to.
651
        """
652

653
        def backend_select(keyset, *args, **kwargs):
654
            device = args[device_arg_index].type
655
            if device not in self._backend_fns:
656
                raise RuntimeError(
657
                    f"{self._name} does not have a kernel registered for {device}. "
658
                    "Please use register_kernel to do so."
659
                )
660
            dispatch_key = _C._dispatch_key_for_device(device)
661
            dispatch_key = getattr(_C.DispatchKey, dispatch_key)
662
            return self._opoverload.redispatch(
663
                _C.DispatchKeySet(dispatch_key), *args, **kwargs
664
            )
665

666
        self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True)
667

668
    def __call__(self, *args, **kwargs):
669
        return self._opoverload(*args, **kwargs)
670

671
    def register_vmap(
672
        self,
673
        func: Optional[Callable] = None,
674
    ):
675
        r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
676

677
        This API may be used as a decorator.
678

679
        In order for an operator to work with :func:`torch.vmap`, you may need to register a
680
        vmap implementation in the following signature:
681

682
            ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,
683

684
        where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
685

686
        It specifies how do we compute the batched version of ``op`` given inputs with an additional
687
        dimension (specified by ``in_dims``).
688

689
        For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
690
        if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
691
        specifying what dimension of the Tensor is being vmapped over.
692

693
        ``info`` is a collection of additional metadata that may be helpful:
694
        ``info.batch_size`` specifies the size of the dimension being vmapped over, while
695
        ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
696

697
        The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
698
        ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
699
        per output that specifies if the output has the vmapped dimension and what index it is in.
700

701
        Examples:
702
            >>> import torch
703
            >>> import numpy as np
704
            >>> from torch import Tensor
705
            >>> from typing import Tuple
706
            >>>
707
            >>> def to_numpy(tensor):
708
            >>>     return tensor.cpu().numpy()
709
            >>>
710
            >>> lib = torch.library.Library("mylib", "FRAGMENT")
711
            >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
712
            >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
713
            >>>     x_np = to_numpy(x)
714
            >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
715
            >>>     return torch.tensor(x_np ** 3, device=x.device), dx
716
            >>>
717
            >>> def numpy_cube_vmap(info, in_dims, x):
718
            >>>     result = numpy_cube(x)
719
            >>>     return result, (in_dims[0], in_dims[0])
720
            >>>
721
            >>> numpy_cube.register_vmap(numpy_cube_vmap)
722
            >>>
723
            >>> x = torch.randn(3)
724
            >>> torch.vmap(numpy_cube)(x)
725
            >>>
726
            >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
727
            >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
728
            >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
729
            >>>
730
            >>> @numpy_mul.register_vmap
731
            >>> def numpy_mul_vmap(info, in_dims, x, y):
732
            >>>     x_bdim, y_bdim = in_dims
733
            >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
734
            >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
735
            >>>     result = x * y
736
            >>>     result = result.movedim(-1, 0)
737
            >>>     return result, 0
738
            >>>
739
            >>>
740
            >>> x = torch.randn(3)
741
            >>> y = torch.randn(3)
742
            >>> torch.vmap(numpy_mul)(x, y)
743
        """
744
        from torch._functorch.autograd_function import custom_function_call_vmap_helper
745
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
746

747
        def register(func):
748
            need_register = self._vmap_fn is None
749
            self._vmap_fn = func
750

751
            if need_register:
752

753
                def wrapped_func(keyset, *args, **kwargs):
754
                    interpreter = retrieve_current_functorch_interpreter()
755
                    return custom_function_call_vmap_helper(
756
                        interpreter, self._vmap_fn, self._opoverload, *args, **kwargs
757
                    )
758

759
                self._lib.impl(
760
                    self._name, wrapped_func, "FuncTorchBatched", with_keyset=True
761
                )
762

763
        if func is None:
764
            return register
765
        else:
766
            return register(func)
767

768

769
# NOTE: [Supporting decorator and non-decorator usage]
770
#
771
# Some APIs can be used either as a decorator or as a plain function call.
772
# For example:
773
#
774
# >>> def fn(x):
775
# >>>     return x.sin()
776
# >>>
777
# >>> # Usage 1: not as a decorator
778
# >>> numpy_sin.register_kernel("cuda", fn)
779
# >>>
780
# >>> # Usage 2: as a decorator
781
# >>> @numpy_sin.register_kernel("cuda")
782
# >>> def fn2(x):
783
# >>>     return x.sin()
784
#
785
# The way we support this is that `register_kernel` accepts an optional `fn`.
786
# If `fn` is provided (Usage 1), then we know that the user is using it not
787
# as a decorator.
788
# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a
789
# decorator.
790

791

792
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
793
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
794

795

796
def get_library_allowing_overwrite(
797
    namespace: str, name: str
798
) -> "torch.library.Library":
799
    qualname = f"{namespace}::{name}"
800

801
    if qualname in OPDEF_TO_LIB:
802
        OPDEF_TO_LIB[qualname]._destroy()
803
        del OPDEF_TO_LIB[qualname]
804

805
    lib = torch.library.Library(namespace, "FRAGMENT")  # noqa: TOR901
806
    OPDEF_TO_LIB[qualname] = lib
807
    return lib
808

809

810
def iter_tensors(
    args: Tuple[Any, ...], kwargs: Dict[str, Any], allowed_nesting: int = 1
) -> Iterator[Tensor]:
    """Yield every Tensor found in ``args`` and in the values of ``kwargs``.

    Positional args are scanned first, then kwarg values in dict iteration order.
    A tuple or list argument is searched recursively while ``allowed_nesting`` is
    positive; each level of nesting consumes one unit. With the default of 1,
    ``args=(t1, [t2, (t3,)])`` yields ``t1`` and ``t2`` but not ``t3``.
    Non-Tensor values, and containers beyond the nesting budget, are skipped.

    Args:
        args: Positional arguments to scan (any length).
        kwargs: Keyword arguments; only the values are scanned.
        allowed_nesting: How many levels of tuple/list nesting to descend into.
    """

    def check(arg):
        if isinstance(arg, Tensor):
            yield arg
        elif allowed_nesting > 0 and isinstance(arg, (tuple, list)):
            # Recurse with one less level of nesting; nested containers have no kwargs.
            yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1)

    for arg in args:
        yield from check(arg)
    for kwarg in kwargs.values():
        yield from check(kwarg)
823

824

825
def _maybe_get_opdef(
    op: Union[CustomOpDef, _ops.OpOverload, str]
) -> Optional[CustomOpDef]:
    """Return the CustomOpDef associated with ``op``, or ``None`` if there is none.

    ``op`` may be a CustomOpDef (returned as-is), an OpOverload (looked up by
    its ``_name``), or a name string looked up directly in ``OPDEFS``.
    """
    if isinstance(op, CustomOpDef):
        return op
    if isinstance(op, _ops.OpOverload):
        op = op._name
    assert isinstance(op, str)
    # OPDEFS is a WeakValueDictionary. With an ``in`` check followed by ``[]``,
    # the value could be garbage collected between the two lookups and raise
    # KeyError. A single ``get`` closes that window (and avoids the double lookup).
    return OPDEFS.get(op)
836

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.