import dataclasses
import functools
import inspect
import sys
import typing
import weakref

from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy

import torch
import torch._C as _C
import torch.library as library
from torch._library.abstract_impl import AbstractImplCtx
from torch.library import get_ctx

from .autograd import autograd_kernel_indirection, construct_autograd_kernel

"""
For a detailed guide on custom ops, please see
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

This file includes pieces of the implementation of our custom operator API.
"""

__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]


SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}


def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""Creates a new CustomOp object.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define (create) the op
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the CustomOp object (the first step);
    you must then perform the second step by calling various methods on
    the CustomOp object.

    This API is used as a decorator (see examples).

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module. The operator_name must be
            the same as the name of the function you pass to custom_op
            (see examples).
        manual_schema (Optional[str]): Each PyTorch operator needs a schema that
            tells PyTorch the types of the inputs/outputs. If None (default),
            we will infer the schema from the type annotations on the function
            (see examples). Otherwise, if you don't want to use type annotations,
            you may provide us the schema string.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the CustomOp.
        >>> # We need to provide the decorator a "prototype function"
        >>> # (a function with Python ellipses as the body).
        >>> @custom_op("my_library::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> # numpy_sin is now an instance of class CustomOp
        >>> print(type(numpy_sin))
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @numpy_sin.impl('cpu')
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @numpy_sin.impl('cuda')
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda

    """

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner


# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}


class CustomOp:
    r"""Class for custom operators in PyTorch.

    Use the CustomOp API to create user-defined custom operators that behave
    just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
    comes to various PyTorch subsystems (like torch.compile).

    To construct a `CustomOp`, use `custom_op`.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        super().__init__()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        return self._impls[kind]

    def _has_impl(self, kind):
        return kind in self._impls

    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""Register an implementation for a device type for this CustomOp object.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        If the CustomOp is passed multiple Tensor inputs with different device
        types, it will dispatch to the registered implementation for the highest
        priority device type among those present.
        The supported device types, in order of priority, are {'cuda', 'cpu'}.

        This API is used as a decorator (see examples).

        Arguments:
            device_types (str or Iterable[str]): the device type(s) to register the function for.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @custom_op("my_library::numpy_cos")
            >>> def numpy_cos(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> # Register an implementation for CPU Tensors
            >>> @numpy_cos.impl('cpu')
            >>> def numpy_cos_impl_cpu(x):
            >>>     return torch.from_numpy(np.cos(x.numpy()))
            >>>
            >>> # Register an implementation for CUDA Tensors
            >>> @numpy_cos.impl('cuda')
            >>> def numpy_cos_impl_cuda(x):
            >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
            >>>
            >>> x = torch.randn(3)
            >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
            >>>
            >>> x_cuda = x.cuda()
            >>> numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""Register an abstract implementation for this operator.

        WARNING: please do not use this directly (and instead use the torch._custom_ops
        APIs). Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        An "abstract implementation" specifies the behavior of this operator on
        Tensors that carry no data. Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        The abstract implementation has the same signature as the operator.
        It is run for both FakeTensors and meta tensors. To write an abstract
        implementation, assume that all Tensor inputs to the operator are
        regular CPU/CUDA/Meta tensors, but they do not have storage, and
        you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
        The abstract implementation must consist of only PyTorch operations
        (and may not directly access the storage or data of any input or
        intermediate Tensors).

        This API is used as a decorator (see examples).

        Examples::
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @custom_op('my_library::custom_linear')
            >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_linear.impl_abstract()
            >>> def custom_linear_abstract(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @custom_op('my_library::custom_nonzero')
            >>> def custom_nonzero(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_nonzero.impl_abstract()
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch._custom_op.get_ctx()
            >>>     nnz = ctx.create_unbacked_symint()
            >>>     shape = [x.dim(), nnz]
            >>>     result = x.new_empty(shape, dtype=torch.long)
            >>>     return result
            >>>
            >>> @custom_nonzero.impl(['cpu', 'cuda'])
            >>> def custom_nonzero_impl(x):
            >>>     x_np = to_numpy(x)
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
            >>>     # constrain the range to at least 2
            >>>     if res.shape[0] <= 1:
            >>>         raise RuntimeError("not supported")
            >>>     return torch.tensor(res, device=x.device)

        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation via a pre-existing registration "
                f"to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation via a pre-existing registration "
                f"to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.

        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.

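        Illustrative sketch (the ``my_library::numpy_mul`` op and the functions
        below are hypothetical, and we assume ``inputs`` exposes the operator's
        arguments by name)::
            >>> @custom_op("my_library::numpy_mul")
            >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @numpy_mul.impl_save_for_backward()
            >>> def numpy_mul_save_for_backward(inputs, output):
            >>>     # Save both Tensor inputs; this is what `saved` will be.
            >>>     return (inputs.x, inputs.y)
            >>>
            >>> @numpy_mul.impl_backward()
            >>> def numpy_mul_backward(ctx, saved, grad):
            >>>     x, y = saved
            >>>     # One entry per Tensor input of the CustomOp.
            >>>     return {"x": grad * y, "y": grad * x}
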
        """
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp. "
                    f"Got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
        return inner


@dataclasses.dataclass
class FuncAndLocation:
    func: typing.Callable
    location: str


def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )


def validate_namespace(ns: str) -> None:
    if "." in ns:
        raise ValueError(
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    if ns in RESERVED_NS:
        raise ValueError(
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )

def validate_schema(schema: FunctionSchema) -> None:
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and have at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )


def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    names = qualname.split("::", 1)
    if len(names) != 2:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. the "
                         f"operator name should look something like ns::foo")
    if '.' in names[1]:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return names[0], names[1]
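
# For illustration: parse_qualname("my_library::foo") returns ("my_library", "foo"),
# while parse_qualname("foo") (no namespace) and parse_qualname("my_library::foo.out")
# (an overload name) both raise ValueError.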


def validate_device_type(device_type: str) -> None:
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
            f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )


def supported_param(param: inspect.Parameter) -> bool:
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)


def infer_schema(prototype_function: typing.Callable) -> str:
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        raise ValueError(
            f"custom_op(...)(func): {what} " f"Got func with signature {sig}"
        )

    params = [
        parse_param(name, param, error_fn) for name, param in sig.parameters.items()
    ]
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(params)}) -> {ret}"
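
# For illustration (a hypothetical prototype, not defined in this file): given
#     def numpy_add(x: torch.Tensor, scale: float) -> torch.Tensor: ...
# infer_schema returns the string "(Tensor x, float scale) -> Tensor".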


def parse_param(name, param, error_fn):
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    if param.annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    if param.default is not inspect.Parameter.empty:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"


def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    result = [
        (base_type, cpp_type),
        (typing.Optional[base_type], f"{cpp_type}?"),
    ]
    if list_base:
        result.append((typing.Sequence[base_type], f"{cpp_type}[]"))  # type: ignore[valid-type]
    if optional_base_list:
        result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"))  # type: ignore[valid-type]
    if optional_list_base:
        result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?"))  # type: ignore[valid-type]
    return result


def get_supported_param_types():
    data = [
        # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    ]
    result = []
    for line in data:
        result.extend(derived_types(*line))
    return dict(result)
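
# A few of the resulting entries, for illustration:
#   int                                             -> "SymInt"
#   typing.Optional[int]                            -> "SymInt?"
#   typing.Sequence[torch.Tensor]                   -> "Tensor[]"
#   typing.Sequence[typing.Optional[torch.Tensor]]  -> "Tensor?[]"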


SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}


def parse_return(annotation, error_fn):
    origin = typing.get_origin(annotation)
    if origin is not tuple:
        if annotation not in SUPPORTED_RETURN_TYPES.keys():
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )
        return SUPPORTED_RETURN_TYPES[annotation]

    args = typing.get_args(annotation)
    for arg in args:
        if arg not in SUPPORTED_RETURN_TYPES:
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )

    return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")"
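
# For illustration (hypothetical annotations): parse_return(torch.Tensor, error_fn)
# yields "Tensor", and parse_return(typing.Tuple[torch.Tensor, int], error_fn)
# yields "(Tensor, SymInt)".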


SUPPORTED_PARAM_TYPES = get_supported_param_types()


def report_error_callback(custom_op: typing.Any, key: str) -> None:
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    if key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    if key in ("CPU", "CUDA"):
        device = key.lower()
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )


def custom_op_from_existing(op):
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)


def get_op(qualname):
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default


def _find_custom_op(qualname, also_check_torch_library=False):
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f"Could not find custom op \"{qualname}\". Did you register it via "
            f"the torch._custom_ops API?")
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result


def get_abstract_impl(qualname):
    if qualname not in torch._custom_op.impl.global_registry:
        return None
    custom_op = torch._custom_op.impl.global_registry[qualname]
    if custom_op is None:
        return None
    if not custom_op._has_impl("abstract"):
        return None
    return custom_op._get_impl("abstract").func


def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)