# torch/_custom_op/impl.py (from the pytorch project)
1import dataclasses
2import functools
3import inspect
4import sys
5import typing
6import weakref
7
8from torchgen.model import FunctionSchema, OperatorName, SchemaKind, BaseType, ListType, BaseTy
9
10import torch
11import torch._C as _C
12import torch.library as library
13from torch._library.abstract_impl import AbstractImplCtx
14from torch.library import get_ctx
15
16from .autograd import autograd_kernel_indirection, construct_autograd_kernel
17
18"""
19For a detailed guide on custom ops, please see
20https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
21
22This file includes pieces of the implementation of our custom operator API.
23"""
24
25__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
26
27
# Maps a user-facing device type string (as passed to CustomOp.impl) to the
# corresponding PyTorch DispatchKey name.
SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}
43
44
def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""Creates a new CustomOp object.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    In PyTorch, defining an op (short for "operator") is a two step-process:
    - we need to define (create) the op
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the CustomOp object (the first step);
    you must then perform the second step by calling various methods on
    the CustomOp object.

    This API is used as a decorator (see examples).

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module. The operator_name must be
            the same as the name of the function you pass to custom_op
            (see examples).
        manual_schema (Optional[str]): Each PyTorch operator needs a schema that
            tells PyTorch the types of the inputs/outputs. If None (default),
            we will infer the schema from the type annotations on the function
            (see examples). Otherwise, if you don't want to use type annotations,
            you may provide us the schema string.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the CustomOp.
        >>> # We need to provide the decorator a "prototype function"
        >>> # (a function with Python ellipses as the body).
        >>> @custom_op("my_library::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> # numpy_sin is now an instance of class CustomOp
        >>> print(type(numpy_sin))
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @numpy_sin.impl('cpu')
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @numpy_sin.impl('cuda')
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda

    """

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        # Build and validate the dispatcher schema, either inferred from type
        # annotations or supplied explicitly by the caller.
        schema = infer_schema(func) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        # Define the operator with the C++ dispatcher and wrap it in a CustomOp.
        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        # NOTE [CustomOp autograd kernel indirection]: route Autograd through a
        # trampoline so backward formulas can be registered lazily later.
        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        # weakref.proxy avoids the error callback keeping the CustomOp alive.
        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner
159
160
# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. aten::foo)
global_registry: typing.Dict[str, "CustomOp"] = {}
168
169
class CustomOp:
    r"""Class for custom operators in PyTorch.

    Use the CustomOp API to create user-defined custom operators that behave
    just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
    comes to various PyTorch subsystems (like torch.compile).

    To construct a `CustomOp`, use `custom_op`.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        super().__init__()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib: library.Library = lib
        self._ophandle: _C._DispatchOperatorHandle = ophandle
        # Has the name of the op, e.g. "foo". We cache here for convenience.
        self._opname: str = operator_name
        # this is _opname but with namespace. e.g. "custom::foo"
        self._qualname: str = name
        self.__name__ = None  # mypy requires this
        # NB: Some of these impls are registered as kernels to DispatchKeys.
        # Modifying the _impls dict directly won't do anything in that case.
        self._impls: typing.Dict[str, typing.Optional[FuncAndLocation]] = {}
        # See NOTE [CustomOp autograd kernel indirection]
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        """Install the Autograd trampoline kernel for this op (at most once)."""
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd")
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls
    # Note that this doesn't cause torch.library to use the impl, that
    # needs to be done in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None  # Pacify mypy
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        # Remember where the registration came from so duplicate-registration
        # errors can point at the offending source line.
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        """Return the FuncAndLocation recorded for `kind` (KeyError if absent)."""
        return self._impls[kind]

    def _has_impl(self, kind):
        """Return True if an impl of the given kind was registered via Python."""
        return kind in self._impls

    def _destroy(self):
        # NOTE: [CustomOp lifetime]
        # A CustomOp, once created, lives forever. The mechanism is that the
        # global registry holds a reference to it. However, to make testing
        # easier, we want to be able to destroy CustomOp objects.
        # CustomOp._destroy does the job, though it leaves the CustomOp
        # in a garbage state.
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
        # Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
        # issues from caching operators that make testing CustomOp difficult).
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""Register an implementation for a device type for this CustomOp object.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        If the CustomOp is passed multiple Tensor inputs with different device
        types, it will dispatch to the registered implementation for the highest
        priority device type among those present.
        The supported device types, in order of priority, are {'cuda', 'cpu'}.

        This API is used as a decorator (see examples).

        Arguments:
            device_types (str or Iterable[str]): the device type(s) to register the function for.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @custom_op("my_library::numpy_cos")
            >>> def numpy_cos(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> # Register an implementation for CPU Tensors
            >>> @numpy_cos.impl('cpu')
            >>> def numpy_cos_impl_cpu(x):
            >>>     return torch.from_numpy(np.cos(x.numpy()))
            >>>
            >>> # Register an implementation for CUDA Tensors
            >>> @numpy_cos.impl('cuda')
            >>> def numpy_cos_impl_cuda(x):
            >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
            >>>
            >>> x = torch.randn(3)
            >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
            >>>
            >>> x_cuda = x.cuda()
            >>> numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        """Raise if a kernel for this device already exists outside this API."""
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration.")

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""Register an abstract implementation for this operator.

        WARNING: please do not use this directly (and instead use the torch._custom_ops
        APIs). Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        An "abstract implementation" specifies the behavior of this operator on
        Tensors that carry no data. Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        The abstract implementation has the same signature as the operator.
        It is run for both FakeTensors and meta tensors. To write an abstract
        implementation, assume that all Tensor inputs to the operator are
        regular CPU/CUDA/Meta tensors, but they do not have storage, and
        you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
        The abstract implementation must consist of only PyTorch operations
        (and may not directly access the storage or data of any input or
        intermediate Tensors).

        This API is used as a decorator (see examples).

        Examples::
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @custom_op('my_library::custom_linear')
            >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_linear.impl_abstract()
            >>> def custom_linear_abstract(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @custom_op('my_library::custom_nonzero')
            >>> def custom_nonzero(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_nonzero.impl_abstract()
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch._custom_op.get_ctx()
            >>>     nnz = ctx.create_unbacked_symint()
            >>>     shape = [x.dim(), nnz]
            >>>     result = x.new_empty(shape, dtype=torch.long)
            >>>     return result
            >>>
            >>> @custom_nonzero.impl(['cpu', 'cuda'])
            >>> def custom_nonzero_impl(x):
            >>>     x_np = to_numpy(x)
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
            >>>     # constrain the range to at least 2
            >>>     if res.shape[0] <= 1:
            >>>         raise RuntimeError("not supported")
            >>>     return torch.tensor(res, device=x.device)

        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle DispatchKey.Meta registration
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    # NB: originally the message was missing the space after the
                    # first sentence; fixed for readability.
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}. "
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        """Raise unless this op's schema is one we can autogenerate autograd for."""
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        # We make assumptions about the schema's return types.
        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(f"operator with return not in {list(allowed_return_types.values())} (got {ret.type})")

    def _check_doesnt_have_library_autograd_impl(self):
        """Raise if an autograd kernel already exists outside this API."""
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them.")

        # We can improve this by adding "all Autograd<BACKEND> keys", but
        # realistically people will just be using this API for CPU/CUDA for now.
        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                # NB: fixed typo "vi a" -> "via a" in the message below.
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs")

    def _check_doesnt_have_library_meta_impl(self):
        """Raise if a Meta kernel already exists outside this API."""
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd,
        # allow them to impl_abstract. This is being pragmatic
        # (existing custom ops may have CompositeExplicitAutograd
        # registration that don't work with Meta kernels, so this
        # gives them an escape hatch).
        if (
            _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeExplicitAutograd")
            and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta")
        ):
            return

        # Otherwise, if the user's already has a Meta kernel or their
        # op is CompositeImplicitAutograd or some other alias dispatch key,
        # raise.

        # Special case for CompositeImplicitAutograd
        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "CompositeImplicitAutograd"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to DispatchKey::CompositeImplicitAutograd. "
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them.")

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            # NB: fixed grammar "an DispatchKey" -> "a DispatchKey".
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract.")

    # NOTE ["backward", "save_for_backward", and "autograd"]
    # As a part of the explicit autograd API, a user must provide us
    # a "save_for_backward" function and a "backward" function.
    # When both of these have been provided, then we automatically
    # construct the "autograd" kernel.
    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func)
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
        """
        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()
            # Return f for consistency with the other impl_* decorators, so
            # using this as a decorator does not rebind the function to None.
            return f
        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.

        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.

        """
        if output_differentiability is not None:
            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}")

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()
            # Return f for consistency with the other impl_* decorators.
            return f
        return inner
633
634
@dataclasses.dataclass
class FuncAndLocation:
    # The registered implementation function.
    func: typing.Callable
    # "filename:lineno" of the registration call site, used in error messages.
    location: str
639
640
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    """Look up the C++ dispatcher handle for ``cpp_ns::operator_name``.

    Raises if the operator schema has not been registered with the dispatcher.
    """
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )
648
649
def validate_namespace(ns: str) -> None:
    """Reject namespaces containing '.' or colliding with PyTorch-reserved names."""
    contains_dot = "." in ns
    if contains_dot:
        raise ValueError(
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    is_reserved = ns in RESERVED_NS
    if is_reserved:
        raise ValueError(
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )
661
def validate_schema(schema: FunctionSchema) -> None:
    """Ensure the schema is functional and does not take a 'self' argument."""
    is_functional = torch._library.utils.is_functional_schema(schema)
    if not is_functional:
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow self arguments
    has_self = schema.arguments.self_arg is not None
    if has_self:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )
677
678
def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    """Split a "ns::opname" string into (ns, opname).

    Rejects strings without a namespace separator and overloaded names
    (operator names containing a '.').
    """
    ns, sep, opname = qualname.partition("::")
    if not sep:
        raise ValueError(f"Expected there to be a namespace in {qualname}, i.e. The "
                         f"operator name should look something like ns::foo")
    if '.' in opname:
        raise ValueError(f"The torch.custom_ops APIs do not handle overloads, "
                         f"i.e. operator names with '.' in them. "
                         f"Please name your operator something like ns::foo. "
                         f"Got: {qualname}")
    return ns, opname
690
691
def validate_device_type(device_type: str) -> None:
    """Raise unless device_type is one of the supported device strings."""
    if device_type in SUPPORTED_DEVICE_TYPE_TO_KEY:
        return
    raise ValueError(
        f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
        f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
    )
698
699
def supported_param(param: inspect.Parameter) -> bool:
    """Only plain positional-or-keyword and keyword-only parameters are supported."""
    allowed_kinds = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    }
    return param.kind in allowed_kinds
705
706
def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    """Check that `func`'s signature agrees with a user-provided manual schema.

    Called only when manual_schema was passed to custom_op. Verifies that:
    - func uses only positional-or-keyword and keyword-only parameters,
    - func has no type annotations (the manual schema is the source of truth),
    - positional and keyword-only argument names/order match the schema,
    - neither func nor the schema declares default values.

    Raises ValueError on any mismatch.
    """
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    # Annotations would be ambiguous next to a manual schema, so forbid them.
    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    # Partition func's parameters the same way the schema partitions its
    # arguments, so they can be compared pairwise below.
    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default "
            f"arguments. Got "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        # Order-sensitive, name-by-name comparison; defaults forbidden on
        # either side.
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)
770
771
def infer_schema(prototype_function: typing.Callable) -> str:
    """Derive a dispatcher schema string from a function's type annotations."""
    sig = inspect.signature(prototype_function)

    def error_fn(what):
        raise ValueError(
            f"custom_op(...)(func): {what} " f"Got func with signature {sig})"
        )

    param_strs = []
    for name, param in sig.parameters.items():
        param_strs.append(parse_param(name, param, error_fn))
    ret = parse_return(sig.return_annotation, error_fn)
    return f"({', '.join(param_strs)}) -> {ret}"
785
786
def parse_param(name, param, error_fn):
    """Render one parameter as "<schema type> <name>", validating it first."""
    if not supported_param(param):
        error_fn("We do not support positional-only args, varargs, or varkwargs.")

    annotation = param.annotation
    if annotation is inspect.Parameter.empty:
        error_fn(f"Parameter {name} must have a type annotation.")

    if annotation not in SUPPORTED_PARAM_TYPES.keys():
        error_fn(
            f"Parameter {name} has unsupported type {param.annotation}. "
            f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
        )

    has_default = param.default is not inspect.Parameter.empty
    if has_default:
        error_fn(
            f"Parameter {name} has a default value; this is not supported. "
            f"If you want to use default values then create a function with "
            f"default values that calls the CustomOp"
        )

    schema_type = SUPPORTED_PARAM_TYPES[annotation]
    return f"{schema_type} {name}"
808
809
def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    """Expand one (python type, schema type) pair into its derived variants.

    Always emits the base type and its Optional; the three flags control
    whether the T[], T?[], and T[]? variants are emitted as well.
    """
    seq_type = typing.Sequence[base_type]  # type: ignore[valid-type]
    pairs = [
        (base_type, cpp_type),
        (typing.Optional[base_type], f"{cpp_type}?"),
    ]
    if list_base:
        pairs.append((seq_type, f"{cpp_type}[]"))
    if optional_base_list:
        pairs.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"))  # type: ignore[valid-type]
    if optional_list_base:
        pairs.append((typing.Optional[seq_type], f"{cpp_type}[]?"))
    return pairs
824
825
def get_supported_param_types():
    """Build the mapping from python annotation types to schema type strings."""
    # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
    table = [
        (torch.Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (torch.types.Number, "Scalar", True, False, False),
        (torch.dtype, "ScalarType", False, False, False),
        (torch.device, "Device", False, False, False),
    ]
    mapping = {}
    for row in table:
        mapping.update(derived_types(*row))
    return mapping
842
843
# Maps python return annotations to their schema return type strings.
# Note: plain int maps to SymInt so custom ops compose with symbolic shapes.
SUPPORTED_RETURN_TYPES = {
    torch.Tensor: "Tensor",
    typing.List[torch.Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    torch.types.Number: "Scalar",
}
852
853
def parse_return(annotation, error_fn):
    """Render a return annotation (single type or Tuple[...]) as a schema string."""
    def check_supported(typ):
        # The message deliberately reports the whole annotation, not the
        # individual tuple element.
        if typ not in SUPPORTED_RETURN_TYPES:
            error_fn(
                f"Return has unsupported type {annotation}. "
                f"The valid types are: {SUPPORTED_RETURN_TYPES}."
            )

    if typing.get_origin(annotation) is not tuple:
        check_supported(annotation)
        return SUPPORTED_RETURN_TYPES[annotation]

    args = typing.get_args(annotation)
    for arg in args:
        check_supported(arg)

    return "(" + ", ".join(SUPPORTED_RETURN_TYPES[arg] for arg in args) + ")"
873
874
# Computed once at import time: python annotation -> schema type string.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
876
877
def report_error_callback(custom_op: typing.Any, key: str) -> None:
    """Dispatcher error callback: raise a helpful error for a missing kernel.

    Installed via _C._dispatch_set_report_error_callback; `key` is the name of
    the DispatchKey that had no kernel registered.
    """
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    elif key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    elif key in ("CPU", "CUDA"):
        device = key.lower()
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    else:
        raise NotImplementedError(
            f"{custom_op}: No implementation for dispatch key {key}. It is likely "
            f"that we have not added this functionality yet, please either open an "
            f"issue or if you're feeling adventurous, use the low-level "
            f"torch.library API"
        )
906
907
def custom_op_from_existing(op):
    """Wrap an already-registered dispatcher operator in a CustomOp object."""
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    # CustomOp expects the schema string without the namespace
    schema_str = str(op._schema).split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)
917
918
def get_op(qualname):
    """Return the default OpOverload for ``qualname``, or raise ValueError."""
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library.")

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    packet = getattr(opnamespace, name, None)
    if packet is None:
        error_not_found()
    if not hasattr(packet, 'default'):
        error_not_found()
    return packet.default
936
937
def _find_custom_op(qualname, also_check_torch_library=False):
    """Look up a CustomOp by qualname, optionally wrapping a torch.library op."""
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f"Could not find custom op \"{qualname}\". Did you register it via "
            f"the torch._custom_ops API?")
    # Fall back to the dispatcher: wrap the pre-existing overload on the fly.
    overload = get_op(qualname)
    return custom_op_from_existing(overload)
948
949
950def get_abstract_impl(qualname):
951if qualname not in torch._custom_op.impl.global_registry:
952return None
953custom_op = torch._custom_op.impl.global_registry[qualname]
954if custom_op is None:
955return None
956if not custom_op._has_impl("abstract"):
957return None
958return custom_op._get_impl("abstract").func
959
960
def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    """Define and register a CustomOp from an explicit schema string.

    Unlike `custom_op`, this takes the schema directly (no prototype function)
    and immediately installs the autograd indirection kernel. Returns the
    OpOverload (not the CustomOp) so callers can invoke it via torch.ops.
    """
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    # Tag so that inductor-style passes preserve input stride order when asked.
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    # weakref.proxy avoids the error callback keeping the CustomOp alive.
    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)
977