from __future__ import annotations

import dataclasses
import itertools
import re

from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Iterator, Sequence

from torchgen.utils import assert_never, NamespaceHelper, OrderedSet


@dataclass(frozen=True)
class Location:
    file: str
    line: int

    def __str__(self) -> str:
        return f"{self.file}:{self.line}"


DEFAULT_KERNEL_NAMESPACE = "at::native"

# BACKEND_COMPONENTS and FUNCTIONALITY_KEYS together generate the
# per-backend dispatch keys (see the consistency check further below).
BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "SparseCsr", "NestedTensor", "Autograd"]

AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
    "Autograd" + component for component in BACKEND_COMPONENTS
]

FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}


class DispatchKey(Enum):
    # ...
    CustomRNGKeyId = auto()
    # ...
    PythonTLSSnapshot = auto()
    PythonDispatcher = auto()
    FuncTorchDynamicLayerBackMode = auto()
    # ...
    BackendSelect = auto()
    AutogradOther = auto()
    AutogradFunctionality = auto()
    AutogradNestedTensor = auto()
    # ...
    AutocastCUDA = auto()
    # ...
    FuncTorchGradWrapper = auto()
    FuncTorchBatched = auto()
    BatchedNestedTensor = auto()
    FuncTorchVmapMode = auto()
    FuncTorchDynamicLayerFrontMode = auto()
    Functionalize = auto()
    TESTING_ONLY_GenericWrapper = auto()
    TESTING_ONLY_GenericMode = auto()
    ADInplaceOrView = auto()
    CompositeImplicitAutograd = auto()
    CompositeImplicitAutogradNestedTensor = auto()
    CompositeExplicitAutograd = auto()
    CompositeExplicitAutogradNonFunctional = auto()
    FuncTorchBatchedDecomposition = auto()
    # Per-backend functionality keys, one per FUNCTIONALITY_KEYS x
    # BACKEND_COMPONENTS combination (the consistency check below
    # enforces that every combination is present).
    QuantizedCPU = auto()
    QuantizedCUDA = auto()
    QuantizedHIP = auto()
    QuantizedXLA = auto()
    QuantizedMTIA = auto()
    QuantizedMPS = auto()
    QuantizedIPU = auto()
    QuantizedXPU = auto()
    QuantizedHPU = auto()
    QuantizedVE = auto()
    QuantizedLazy = auto()
    QuantizedMeta = auto()
    QuantizedPrivateUse1 = auto()
    QuantizedPrivateUse2 = auto()
    QuantizedPrivateUse3 = auto()
    SparseCPU = auto()
    SparseCUDA = auto()
    SparseHIP = auto()
    SparseXLA = auto()
    SparseMTIA = auto()
    SparseMPS = auto()
    SparseIPU = auto()
    SparseXPU = auto()
    SparseHPU = auto()
    SparseVE = auto()
    SparseLazy = auto()
    SparseMeta = auto()
    SparsePrivateUse1 = auto()
    SparsePrivateUse2 = auto()
    SparsePrivateUse3 = auto()
    SparseCsrCPU = auto()
    SparseCsrCUDA = auto()
    SparseCsrHIP = auto()
    SparseCsrXLA = auto()
    SparseCsrMTIA = auto()
    SparseCsrMPS = auto()
    SparseCsrIPU = auto()
    SparseCsrXPU = auto()
    SparseCsrHPU = auto()
    SparseCsrVE = auto()
    SparseCsrLazy = auto()
    SparseCsrMeta = auto()
    SparseCsrPrivateUse1 = auto()
    SparseCsrPrivateUse2 = auto()
    SparseCsrPrivateUse3 = auto()
    NestedTensorCPU = auto()
    NestedTensorCUDA = auto()
    NestedTensorHIP = auto()
    NestedTensorXLA = auto()
    NestedTensorMTIA = auto()
    NestedTensorMPS = auto()
    NestedTensorIPU = auto()
    NestedTensorXPU = auto()
    NestedTensorHPU = auto()
    NestedTensorVE = auto()
    NestedTensorLazy = auto()
    NestedTensorMeta = auto()
    NestedTensorPrivateUse1 = auto()
    NestedTensorPrivateUse2 = auto()
    NestedTensorPrivateUse3 = auto()
    AutogradCPU = auto()
    AutogradCUDA = auto()
    AutogradHIP = auto()
    AutogradXLA = auto()
    AutogradMTIA = auto()
    AutogradMPS = auto()
    AutogradIPU = auto()
    AutogradXPU = auto()
    AutogradHPU = auto()
    AutogradVE = auto()
    AutogradLazy = auto()
    AutogradMeta = auto()
    AutogradPrivateUse1 = auto()
    AutogradPrivateUse2 = auto()
    AutogradPrivateUse3 = auto()

    def __str__(self) -> str:
        return self.name

    def lower(self) -> str:
        return str(self).lower()

    @staticmethod
    def parse(value: str) -> DispatchKey:
        for k, v in DispatchKey.__members__.items():
            if k == value:
                return v
        raise AssertionError(f"unknown dispatch key {value}")
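
    # Illustrative usage (sketch, not in the upstream file): parse() simply
    # round-trips an enum member name, and lower() is handy for generating
    # file and macro names:
    #
    #   DispatchKey.parse("QuantizedCPU") is DispatchKey.QuantizedCPU
    #   str(DispatchKey.QuantizedCPU) == "QuantizedCPU"
    #   DispatchKey.QuantizedCPU.lower() == "quantizedcpu"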


class _TorchDispatchModeKey(Enum):
    FAKE = auto()
    PROXY = auto()
    FUNCTIONAL = auto()


def codegen_per_backend_entries() -> str:
    r = []
    for fk in FUNCTIONALITY_KEYS:
        for bc in BACKEND_COMPONENTS:
            r.append(f"    {fk}{bc} = auto()")
    return "\n".join(r)


# Consistency check: every per-backend functionality key must be present
# in the DispatchKey enum above.
for fk in FUNCTIONALITY_KEYS:
    for bc in BACKEND_COMPONENTS:
        if not hasattr(DispatchKey, fk + bc):
            r = codegen_per_backend_entries()
            raise RuntimeError(
                f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}"
            )


STRUCTURED_DISPATCH_KEYS = {
    DispatchKey.MPS,
    DispatchKey.CUDA,
    DispatchKey.CPU,
}
UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}

# Set of supported dispatch keys
dispatch_keys = [
    DispatchKey.CPU,
    DispatchKey.SparseCPU,
    DispatchKey.SparseCsrCPU,
    DispatchKey.MkldnnCPU,
    DispatchKey.CUDA,
    DispatchKey.MPS,
    DispatchKey.SparseCUDA,
    DispatchKey.SparseCsrCUDA,
    DispatchKey.QuantizedCPU,
    DispatchKey.QuantizedCUDA,
    DispatchKey.CompositeImplicitAutograd,
    DispatchKey.CompositeImplicitAutogradNestedTensor,
    DispatchKey.CompositeExplicitAutograd,
    DispatchKey.CompositeExplicitAutogradNonFunctional,
    DispatchKey.NestedTensorCPU,
    DispatchKey.NestedTensorCUDA,
    # Meta is a magic key: it is automatically generated for structured kernels
    DispatchKey.Meta,
    DispatchKey.SparseMeta,
    DispatchKey.SparseCsrMeta,
    DispatchKey.QuantizedMeta,
    DispatchKey.NestedTensorMeta,
    DispatchKey.ZeroTensor,
]


# Dispatch keys that "support all backends".
def is_generic_dispatch_key(dk: DispatchKey) -> bool:
    return dk in {
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
    }


# CUDA specific dispatch keys
def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
    return dk in {
        DispatchKey.CUDA,
        DispatchKey.QuantizedCUDA,
        DispatchKey.SparseCUDA,
        DispatchKey.SparseCsrCUDA,
        DispatchKey.NestedTensorCUDA,
        DispatchKey.AutogradCUDA,
    }


# Structured kernel generation is only supported for certain key types;
# otherwise use old-style
def is_structured_dispatch_key(dk: DispatchKey) -> bool:
    return dk in STRUCTURED_DISPATCH_KEYS


def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
    return dk in UFUNC_DISPATCH_KEYS


class ScalarType(Enum):
    Byte = auto()
    Char = auto()
    Short = auto()
    Int = auto()
    Long = auto()
    Half = auto()
    Float = auto()
    Double = auto()
    ComplexHalf = auto()
    ComplexFloat = auto()
    ComplexDouble = auto()
    Bool = auto()
    BFloat16 = auto()
    Float8_e5m2 = auto()
    Float8_e5m2fnuz = auto()
    Float8_e4m3fn = auto()
    Float8_e4m3fnuz = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def maybe_parse(value: str) -> ScalarType | None:
        for k, v in ScalarType.__members__.items():
            if k == value:
                return v
        return None

    @staticmethod
    def parse(value: str) -> ScalarType:
        mb_r = ScalarType.maybe_parse(value)
        assert mb_r is not None, f"unknown dtype {value}"
        return mb_r

    @staticmethod
    def parse_set(values: str) -> OrderedSet[ScalarType]:
        dtypes: OrderedSet[ScalarType] = OrderedSet()
        for value in values.split(", "):
            if value in DTYPE_CLASSES:
                dtypes.update(DTYPE_CLASSES[value])
            else:
                dtypes.add(ScalarType.parse(value))
        return dtypes


DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet(
    [
        ScalarType.Byte,
        ScalarType.Char,
        ScalarType.Int,
        ScalarType.Long,
        ScalarType.Short,
    ]
)
# NB: Floating doesn't include low precision types
DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
DTYPE_CLASSES["Complex"] = OrderedSet(
    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
)
DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
DTYPE_CLASSES["FloatingAndComplex"] = (
    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
)


class UfuncKey(Enum):
    Generic = auto()
    # ...
    CUDAFunctorOnOther = auto()
    CUDAFunctorOnSelf = auto()

    def __str__(self) -> str:
        return self.name

    @staticmethod
    def parse(value: str) -> UfuncKey:
        for k, v in UfuncKey.__members__.items():
            if k == value:
                return v
        raise AssertionError(f"unknown ufunc key {value}")


class DeviceCheckType(Enum):
    NoCheck = 0
    ExactSame = 1


class ViewSchemaKind(Enum):
    aliasing = auto()
    aliasing_inplace = auto()
    non_aliasing = auto()


class Variant(Enum):
    function = auto()
    method = auto()


@dataclass(frozen=True)
class NativeFunction:
    # The namespace for this operator (e.g. "aten" for native functions).
    namespace: str

    # The parsed function schema of the operator in question.
    func: FunctionSchema

    use_const_ref_for_mutable_tensors: bool
    device_guard: bool
    device_check: DeviceCheckType
    python_module: str | None
    category_override: str | None
    variants: set[Variant]
    manual_kernel_registration: bool
    manual_cpp_binding: bool
    loc: Location
    autogen: list[OperatorName]
    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
    structured: bool
    structured_delegate: OperatorName | None
    structured_inherits: str | None
    precomputed: Precompute | None
    cpp_no_default_args: set[str]
    is_abstract: bool
    has_composite_implicit_autograd_kernel: bool
    has_composite_implicit_autograd_nested_tensor_kernel: bool
    has_composite_explicit_autograd_kernel: bool
    has_composite_explicit_autograd_non_functional_kernel: bool
    tags: set[str]

    @staticmethod
    def from_yaml(
        ei: dict[str, object],
        loc: Location,
        valid_tags: set[str],
        ignore_keys: set[DispatchKey] | None = None,
    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
        """
        Parse a NativeFunction from a dictionary as directly parsed
        from native_functions.yaml
        """
        e = ei.copy()

        funcs = e.pop("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        # only support one level of namespace, e.g. aten::add
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        namespace = namespace_helper.get_cpp_namespace(default="aten")
        func = FunctionSchema.parse(namespace_helper.entity_name)

        cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
        assert isinstance(cpp_no_default_args_list, list)
        cpp_no_default_args = set(cpp_no_default_args_list)

        use_const_ref_for_mutable_tensors = e.pop(
            "use_const_ref_for_mutable_tensors", False
        )
        assert isinstance(use_const_ref_for_mutable_tensors, bool)

        variants_s = e.pop("variants", "function")
        assert isinstance(variants_s, str)
        variants: set[Variant] = set()
        for v in variants_s.split(", "):
            if v == "function":
                variants.add(Variant.function)
            elif v == "method":
                variants.add(Variant.method)
            else:
                raise AssertionError(f"illegal variant {v}")

        manual_kernel_registration = e.pop("manual_kernel_registration", False)
        assert isinstance(
            manual_kernel_registration, bool
        ), f"not a bool: {manual_kernel_registration}"

        manual_cpp_binding = e.pop("manual_cpp_binding", False)
        assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"

        device_guard = e.pop("device_guard", True)
        assert isinstance(device_guard, bool), f"not a bool: {device_guard}"

        device_check_s = e.pop("device_check", None)
        assert device_check_s is None or isinstance(
            device_check_s, str
        ), f"not a str: {device_check_s}"
        assert (
            device_check_s is None or device_check_s in DeviceCheckType.__members__
        ), f"illegal device_check: {device_check_s}"
        device_check: DeviceCheckType
        if device_check_s is None:
            device_check = DeviceCheckType.ExactSame
        else:
            device_check = DeviceCheckType[device_check_s]

        structured = e.pop("structured", False)
        assert isinstance(structured, bool), f"not a bool: {structured}"

        structured_delegate_s = e.pop("structured_delegate", None)
        assert structured_delegate_s is None or isinstance(
            structured_delegate_s, str
        ), f"not a str: {structured_delegate_s}"
        assert structured_delegate_s is None or "::" not in structured_delegate_s, (
            "namespace is not supported in structured delegate,"
            " using the same namespace as the native function"
        )
        structured_delegate: OperatorName | None = None
        if structured_delegate_s is not None:
            structured_delegate = OperatorName.parse(structured_delegate_s)

        structured_inherits = e.pop("structured_inherits", None)
        assert structured_inherits is None or isinstance(
            structured_inherits, str
        ), f"not a str: {structured_inherits}"
        assert structured_inherits is None or "::" not in structured_inherits, (
            "namespace is not supported in structured inherits,"
            " using the same namespace as the native function"
        )

        python_module = e.pop("python_module", None)
        assert python_module is None or isinstance(
            python_module, str
        ), f"not a str: {python_module}"
        assert (
            python_module is None or Variant.method not in variants
        ), "functions in modules cannot be methods"

        category_override = e.pop("category_override", None)
        assert category_override is None or isinstance(
            category_override, str
        ), f"not a str: {category_override}"

        precomputed_dict = e.pop("precomputed", None)
        assert precomputed_dict is None or structured is True
        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None

        tags_inp = e.pop("tags", [])
        if isinstance(tags_inp, str):
            tags_inp = [tags_inp]
        assert isinstance(tags_inp, list)

        # Ops in the aten namespace implicitly pick up the pt2_compliant_tag
        # whenever that tag is valid.
        if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
            tags_inp.append("pt2_compliant_tag")

        tags: set[str] = set()
        for t in tags_inp:
            assert len(valid_tags) > 0
            if t in valid_tags:
                tags.add(t)
            else:
                raise AssertionError(f"illegal tag {t}")

        from torchgen.api import cpp

        raw_dispatch = e.pop("dispatch", None)
        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
        dispatch: dict[DispatchKey, BackendMetadata] = {}
        num_dispatch_keys: int = 0
        if raw_dispatch is not None:
            assert not manual_kernel_registration, (
                "cannot specify both manual_kernel_registration and dispatch; with "
                "manual registration, dispatch has no effect!"
            )
            redundant_composite_implicit_autograd = False
            for ks, v in raw_dispatch.items():
                if ks == "__line__":
                    continue  # not worth tracking line numbers for dispatch entries
                assert isinstance(
                    ks, str
                ), f"illegal dispatch key '{ks}' in {raw_dispatch}"
                assert isinstance(
                    v, str
                ), f"illegal dispatch value '{v}' in {raw_dispatch}"
                for k in ks.split(","):
                    dispatch_key = DispatchKey.parse(k.strip())
                    num_dispatch_keys += 1

                    if ignore_keys and dispatch_key in ignore_keys:
                        continue
                    assert dispatch_key in dispatch_keys, (
                        f"Dispatch key {dispatch_key} of kernel {v} "
                        "is not a supported dispatch key."
                    )
                    # A custom kernel namespace gets "::native" appended to it.
                    namespace_helper = NamespaceHelper.from_namespaced_entity(
                        namespaced_entity=v, max_level=3
                    )
                    kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
                    dispatch[dispatch_key] = BackendMetadata(
                        kernel=namespace_helper.entity_name,
                        structured=structured
                        and is_structured_dispatch_key(dispatch_key),
                        cpp_namespace=(kernel_namespace + "::native"),
                    )
                    if (
                        dispatch_key is DispatchKey.CompositeImplicitAutograd
                        and v == cpp.name(func)
                    ):
                        redundant_composite_implicit_autograd = True

            assert not (
                num_dispatch_keys == 1 and redundant_composite_implicit_autograd
            ), (
                "unnecessary dispatch table for this function; just delete the dispatch "
                "key entirely"
            )
            # If a function is a structured delegate, deleting the dispatch
            # table is NOT semantics preserving.
            assert (
                structured_delegate
                or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
                or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
                or num_dispatch_keys != 1
            ), (
                f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
                f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "
                "name, then delete the dispatch table"
            )
        elif not structured and structured_delegate is None:
            name = str(func.name.name)
            assert not (
                name.startswith("new_")
                or name.endswith("_like")
                or (
                    func.arguments.tensor_options
                    and not func.arguments.has_tensor_arg()
                )
            ), (
                f"expected {name} to have a CompositeExplicitAutograd "
                "dispatch entry, but there was no dispatch table. Factory functions "
                "should not have implicit dispatch as they should not be decomposed "
                "for __torch_dispatch__"
            )
            dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
                cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
            )

        composites_in_dispatch = [
            d
            for d in dispatch
            if d == DispatchKey.CompositeExplicitAutograd
            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
            or d == DispatchKey.CompositeImplicitAutograd
            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
        ]

        assert len(composites_in_dispatch) <= 1 or (
            len(composites_in_dispatch) == 2
            and (
                DispatchKey.CompositeExplicitAutogradNonFunctional
                not in composites_in_dispatch
            )
            and (
                DispatchKey.CompositeImplicitAutogradNestedTensor
                in composites_in_dispatch
            )
        ), (
            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
            "or CompositeImplicitAutograd on a single kernel; each "
            "strictly subsumes the other. If you wanted to provide an explicit autograd "
            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
        )

        autogen_str = e.pop("autogen", "")
        assert isinstance(autogen_str, str)
        autogen = (
            []
            if autogen_str == ""
            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
        )

        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
        ufunc_inner_loop = {}
        if isinstance(raw_ufunc_inner_loop, str):
            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
                raw_ufunc_inner_loop, UfuncKey.Generic
            )
        elif isinstance(raw_ufunc_inner_loop, dict):
            for k, vo in raw_ufunc_inner_loop.items():
                if k == "__line__":
                    continue
                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {vo}"
                ufunc_key = UfuncKey.parse(k)
                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
        else:
            raise AssertionError(
                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
            )
        # Program the BackendIndex for the implicit dispatch entry from ufunc
        if ufunc_inner_loop:
            assert structured, "ufunc must be structured"

            # Delay import to avoid a circular import issue
            import torchgen.api.ufunc as ufunc

            for dispatch_key in UFUNC_DISPATCH_KEYS:
                assert (
                    dispatch_key not in dispatch
                ), f"ufunc should not have explicit dispatch entry for {dispatch_key}"
                dispatch[dispatch_key] = BackendMetadata(
                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
                    structured=True,
                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
                )

        if structured_delegate:
            # Structured functions MUST have a dispatch table
            is_abstract = True
        else:
            is_abstract = (
                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
                and dispatch.keys()
                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
                and dispatch.keys()
                != {
                    DispatchKey.CompositeImplicitAutograd,
                    DispatchKey.CompositeImplicitAutogradNestedTensor,
                }
            )

        has_composite_implicit_autograd_kernel = (
            DispatchKey.CompositeImplicitAutograd in dispatch
        )
        has_composite_implicit_autograd_nested_tensor_kernel = (
            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
        )
        has_composite_explicit_autograd_kernel = (
            DispatchKey.CompositeExplicitAutograd in dispatch
        )
        has_composite_explicit_autograd_non_functional_kernel = (
            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
        )

        backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}

        # Don't care whether it exists or not; this makes the function easier
        # to use with yaml parsers that don't set __line__ in the dict.
        e.pop("__line__", None)
        assert not e, f"leftover entries: {e}"

        # Asserts that we can't do in post_init, because they rely on
        # backend-specific info.
        if structured_delegate is not None:
            for key in STRUCTURED_DISPATCH_KEYS:
                assert key not in dispatch, (
                    f"if structured_delegate, then must not have {key} in dispatch dictionary "
                    "(it is delegated!)"
                )

        return (
            NativeFunction(
                func=func,
                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
                variants=variants,
                structured=structured,
                structured_delegate=structured_delegate,
                structured_inherits=structured_inherits,
                precomputed=precomputed,
                autogen=autogen,
                ufunc_inner_loop=ufunc_inner_loop,
                manual_kernel_registration=manual_kernel_registration,
                manual_cpp_binding=manual_cpp_binding,
                python_module=python_module,
                category_override=category_override,
                device_guard=device_guard,
                device_check=device_check,
                loc=loc,
                cpp_no_default_args=cpp_no_default_args,
                is_abstract=is_abstract,
                has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
                has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
                has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
                has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
                tags=tags,
                namespace=namespace,
            ),
            backend_metadata,
        )

    def validate_unstructured(self) -> None:
        assert not self.structured, (
            "This function is structured, but there was "
            "no valid functional variant of it."
        )
        assert self.structured_delegate, (
            "This function delegates to another structured out function, "
            "but no valid function was found (the delegate may not exist, or it has the wrong type)"
        )

    def __post_init__(self) -> None:
        if self.func.arguments.out:
            assert self.variants == {Variant.function}, (
                "Native functions with out arguments MUST "
                "be declared with only function variant; e.g., variants: function; "
                "otherwise you will tickle a Python argument binding bug "
                "(which usually manifests itself as the result variable being undefined.)"
            )
        if self.structured:
            assert self.func.kind() == SchemaKind.out, (
                "Put structured field on the out= "
                "variant of a function; did you mean structured_delegate?"
            )
            assert (
                self.device_guard
            ), "device_guard: False is not respected by structured kernels"
        if self.structured_delegate:
            assert self.func.kind() != SchemaKind.out, (
                "structured_delegate field not allowed "
                "on out= functions; did you mean structured?"
            )
            assert (
                self.device_guard
            ), "device_guard: False is not respected by structured kernels"
        assert not (
            self.structured and self.structured_delegate
        ), "Cannot have both structured and structured_delegate on function"
        defaulted_arguments = {
            a.name for a in self.func.schema_order_arguments() if a.default is not None
        }
        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
        if self.structured_inherits is not None:
            assert (
                self.structured
            ), "structured_inherits must also imply structured: True"
        if str(self.func.name).startswith("_foreach"):
            assert self.device_check == DeviceCheckType.NoCheck, (
                "foreach kernels fall back to slow path when tensor are on different devices, "
                "device_check not allowed to be enabled"
            )
        # Random-looking operators must carry the nondeterministic_seeded tag.
        if (
            "rand" in str(self.func.name)
            or (
                (
                    "dropout" in str(self.func.name)
                    or any(
                        "dropout" in arg.name for arg in self.func.arguments.flat_all
                    )
                )
                # Backwards of dropout is typically deterministic
                and "backward" not in str(self.func.name)
                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
            )
            or self.func.arguments.has_generator_arg()
        ):
            assert "nondeterministic_seeded" in self.tags, str(self.func.name)

    @property
    def has_composite_kernel(self) -> bool:
        return (
            self.has_composite_implicit_autograd_kernel
            or self.has_composite_explicit_autograd_kernel
            or self.has_composite_explicit_autograd_non_functional_kernel
        ) or (
            self.has_composite_implicit_autograd_kernel
            and self.has_composite_implicit_autograd_nested_tensor_kernel
        )

    @property
    def is_view_op(self) -> bool:
        rets = self.func.returns
        is_non_mutating_view = len(rets) > 0 and any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        is_inplace_view = (
            "inplace_view" in self.tags
            and str(self.func.name) != "resize_"
            and str(self.func.name) != "resize_as_"
        )
        is_wildcard_view = any(
            inp.annotation is not None and "*" in inp.annotation.alias_set_after
            for inp in self.func.schema_order_arguments()
        )
        return is_non_mutating_view or is_inplace_view or is_wildcard_view

    @property
    def view_schema_kind(self) -> ViewSchemaKind:
        if self.is_view_op and self.func.name.name.inplace:
            assert "inplace_view" in self.tags
            return ViewSchemaKind.aliasing_inplace
        if self.is_view_op:
            return ViewSchemaKind.aliasing
        else:
            return ViewSchemaKind.non_aliasing

    @property
    def root_name(self) -> str:
        return self.func.name.name.base

    @property
    def part_of_structured_group(self) -> bool:
        return self.structured or self.structured_delegate is not None


class SchemaKind(Enum):
    functional = auto()
    inplace = auto()
    out = auto()
    mutable = auto()
    scratch = auto()


@dataclass(frozen=True)
class NativeFunctionsGroup:
    functional: NativeFunction
    inplace: NativeFunction | None
    mutable: NativeFunction | None
    out: NativeFunction

    @property
    def structured(self) -> bool:
        # Whether or not the operator has a meta() function.
        # This information is backend-agnostic.
        return self.out.structured

    def __post_init__(self) -> None:
        test_sig: FunctionSchema = self.functional.func.signature()
        for f in self.functions():
            if test_sig != f.func.signature():
                raise AssertionError(
                    "NativeFunctionsGroup constructed from two NativeFunctions "
                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
                )

            if self.structured != f.part_of_structured_group:
                raise AssertionError(
                    "NativeFunctionsGroup constructed from structured and unstructured "
                    f"functions: {self.out.func.name} and {f.func.name}"
                )
        assert self.functional.func.kind() == SchemaKind.functional
        assert self.out.func.kind() == SchemaKind.out
        assert self.functional.namespace == self.out.namespace
        if self.inplace is not None:
            assert self.inplace.func.kind() == SchemaKind.inplace
            assert self.inplace.namespace == self.functional.namespace

        if self.mutable is not None:
            assert self.mutable.func.kind() == SchemaKind.mutable
            assert self.mutable.namespace == self.functional.namespace
            # See Note [Overload Ambiguity With Functional Variants]
            assert self.functional.func.name.name.functional_overload

        if self.structured:
            # Structured composite kernels are not supported for now.
            assert (
                not self.out.has_composite_implicit_autograd_kernel
                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
            )

            assert self.functional.structured_delegate == self.out.func.name, (
                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
                f"but its actual delegate is {self.out.func.name}"
            )
            if self.inplace is not None:
                assert self.inplace.structured_delegate == self.out.func.name

        generated_fns = sorted(
            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
        )
        generated_fns_str = ", ".join(str(x) for x in generated_fns)
        expected_generated_fns: set[str] = set()
        for f in self.functions():
            expected_generated_fns.update(str(op) for op in f.autogen)
        expected_generated_fns_str = ", ".join(
            str(x) for x in sorted(expected_generated_fns)
        )
        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                " In order to generate them however, we expect them to be called out explicitly in the yaml."
                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
            )
        if expected_generated_fns_str != generated_fns_str:
            raise RuntimeError(
                f"The codegen expects to be able to generate '{generated_fns_str}'."
                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
            )

    def signature(self) -> FunctionSchema:
        return self.out.func.signature()

    def functions(self) -> Iterator[NativeFunction]:
        yield self.functional
        yield self.out
        if self.inplace is not None:
            yield self.inplace
        if self.mutable is not None:
            yield self.mutable

    @property
    def root_name(self) -> str:
        return self.functional.root_name

    @staticmethod
    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
        assert d
        if len(d) == 1:
            return None
        d = dict(d)  # non-destructive updates please
        functional = d.pop(SchemaKind.functional, None)
        inplace = d.pop(SchemaKind.inplace, None)
        mutable = d.pop(SchemaKind.mutable, None)
        out = d.pop(SchemaKind.out, None)
        assert not d
        assert functional is not None
        # Some operators have only functional/inplace variants;
        # these don't count as structured for our purposes here.
        if out is None:
            return None

        return NativeFunctionsGroup(
            functional=functional,
            inplace=inplace,
            mutable=mutable,
            out=out,
        )


@dataclass(frozen=True)
class BackendMetadata:
    # The name of the backend kernel for a given operator.
    kernel: str
    # Whether the operator has a structured kernel implemented for this backend.
    structured: bool
    # The kernel namespace; defaults to DEFAULT_KERNEL_NAMESPACE.
    cpp_namespace: str

    def supports_symint(self) -> bool:
        return "_symint" in self.kernel


@dataclass(frozen=True)
class UfuncInnerLoop:
    name: str
    supported_dtypes: OrderedSet[ScalarType]
    # The key is stored here because it affects the semantics of name,
    # so it's helpful to keep them together.
    ufunc_key: UfuncKey

    @staticmethod
    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
        name, supported_dtypes_str = value.split(" ", 1)
        assert supported_dtypes_str[0] == "("
        assert supported_dtypes_str[-1] == ")"
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for k in supported_dtypes_str[1:-1].split(", "):
            supported_dtypes |= ScalarType.parse_set(k)
        return UfuncInnerLoop(
            name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
        )
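
# Illustrative usage (sketch; "neg" is a hypothetical kernel name, not taken
# from the upstream file). The yaml value has the form
# "<kernel name> (<dtype list>)", where each entry of the list may be a
# DTYPE_CLASSES name or an individual ScalarType name:
#
#   UfuncInnerLoop.parse("neg (Floating, Complex)", UfuncKey.Generic)
#   # -> UfuncInnerLoop(name="neg",
#   #                   supported_dtypes={Float, Double, ComplexFloat, ComplexDouble},
#   #                   ufunc_key=UfuncKey.Generic)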


@dataclass(frozen=True)
class BackendIndex:
    dispatch_key: DispatchKey
    # Whether the backend uses the out variant as the primary kernel of a
    # structured group.
    use_out_as_primary: bool
    # Whether the backend requires a device guard and device checks.
    device_guard: bool
    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (e.g. XLA).
    external: bool
    # Backend-specific metadata, on a per-operator basis.
    index: dict[OperatorName, BackendMetadata]

    @staticmethod
    def grow_index(
        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
    ) -> None:
        for k, v in child_index.items():
            for op_name, metadata in v.items():
                assert (
                    op_name not in parent_index[k]
                ), f"duplicate operator {op_name} for dispatch key {k}"
                parent_index[k][op_name] = metadata

    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
        if self.use_out_as_primary:
            return g.out
        else:
            return g.functional

    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
        m = self.get_kernel(g)
        return m is not None

    def get_kernel(
        self, g: NativeFunction | NativeFunctionsGroup
    ) -> BackendMetadata | None:
        if isinstance(g, NativeFunction):
            f = g
        elif isinstance(g, NativeFunctionsGroup):
            f = self.primary(g)
        else:
            assert_never(g)
        if f.func.name not in self.index:
            return None
        return self.index[f.func.name]

    def native_function_class_name(self) -> str | None:
        if self.external:
            return f"{str(self.dispatch_key)}NativeFunctions"
        else:
            return None


# The function schema is the most important data structure in the codegen:
# it defines the type signature for operators.
@dataclass(frozen=True)
class FunctionSchema:
    name: OperatorName

    arguments: Arguments

    returns: tuple[Return, ...]

    @property
    def is_mutable(self) -> bool:
        def is_write(arg: Argument) -> bool:
            if arg.annotation is None:
                return False
            return arg.annotation.is_write

        return any(is_write(a) for a in self.arguments.flat_all)

    def schema_order_arguments(self) -> Iterator[Argument]:
        return itertools.chain(
            self.arguments.flat_positional,
            self.arguments.flat_kwarg_only,
            self.arguments.out,
        )

    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

    @staticmethod
    def parse(func: str) -> FunctionSchema:
        decls = FunctionSchema.decl_re.findall(func)
        assert len(decls) == 1, f"Invalid function schema: {func}"
        ops, args, return_decl = decls[0]
        name = OperatorName.parse(ops)
        arguments = Arguments.parse(args)
        returns = parse_returns(return_decl)
        r = FunctionSchema(name=name, arguments=arguments, returns=returns)
        # Round-trip check: printing the parsed schema must reproduce the input.
        assert str(r) == func, f"{str(r)} != {func}"
        return r
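
    # Illustrative usage (sketch, not in the upstream file):
    #
    #   s = FunctionSchema.parse(
    #       "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    #   )
    #   str(s.name) == "add.Tensor" and len(s.returns) == 1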

    def returns_are_aliased(self) -> bool:
        # We assert earlier that schemas can't mix aliased and non-aliased returns.
        return any(
            r
            for r in self.returns
            if r.annotation is not None and r.annotation.is_write
        )

    def __post_init__(self) -> None:
        for arg, ret in zip(self.arguments.out, self.returns):
            assert arg.annotation == ret.annotation, (
                "Out arguments must have matching return Tensor; furthermore, "
                "the ith-argument needs to correspond to the ith return"
            )
        # Mutable positional args must not be returned; this makes it easier
        # to group functions with their functional/out= counterparts.
        for a in self.arguments.post_self_positional_mutable:
            assert not any(
                a.annotation == r.annotation for r in self.returns
            ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"

        out_and_self = list(self.arguments.out) + [
            arg for arg in self.arguments.flat_positional if arg.name == "self"
        ]
        mutable_returns = [
            ret
            for ret in self.returns
            if ret.annotation is not None and ret.annotation.is_write
        ]
        immutable_returns = [
            ret
            for ret in self.returns
            if ret.annotation is None or not ret.annotation.is_write
        ]
        assert (
            len(mutable_returns) == 0 or len(immutable_returns) == 0
        ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
        for ret in mutable_returns:
            assert any(ret.annotation == arg.annotation for arg in out_and_self), (
                'All mutable returns must be aliased either to a keyword argument, or to "self". '
                "Did you forget to mark an out argument as keyword-only?"
            )

        if self.arguments.out:
            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
                assert len(self.returns) == 0, (
                    "out= ops that accept tensor lists as out arguments "
                    "are expected to have no return type (since you can't do method chaining on them)"
                )
            else:
                # Mutable keyword arguments whose names have the _scratch_ prefix
                # are scratch tensors for memory planning; they are not returned.
                assert len(
                    [
                        arg
                        for arg in self.arguments.out
                        if not arg.name.startswith("_scratch_")
                    ]
                ) == len(
                    self.returns
                ), "Must return as many arguments as there are out arguments, or no return at all"

        if self.name.name.inplace:
            self_a = self.arguments.self_arg
            assert (
                self_a
                and self_a.argument.annotation
                and self_a.argument.annotation.is_write
            )
            if self_a.argument.type == BaseType(BaseTy.Tensor):
                # Inplace ops with an ordinary Tensor self should return self,
                # to allow for method chaining.
                assert (
                    len(self.returns) == 1
                    and self.returns[0].annotation == self_a.argument.annotation
                )
            else:
                # You can't method chain on non-Tensor self arguments.
                assert len(self.returns) == 0

        if self.arguments.tensor_options is not None:
            assert self.kind() == SchemaKind.functional, (
                "Found an operator that is not functional or out variant, but has tensor options arguments. "
                "This is not allowed- tensor options arguments are only allowed for factory functions. "
                f"schema: {str(self)}"
            )
        if self.is_functional_fn():
            assert self.kind() == SchemaKind.functional, (
                "Found an operator that is not functional, but its overload contains the string 'functional'. "
                "This is a special keyword in the codegen, please use a different overload name. "
                f"schema: {str(self)}"
            )

    def is_functional_fn(self) -> bool:
        return "functional" in self.name.overload_name

    def is_out_fn(self) -> bool:
        return bool(self.arguments.out)

    def kind(self) -> SchemaKind:
        """
        What kind of schema is this?  A functional schema is one
        that returns a newly allocated output; an inplace schema
        modifies the self argument inplace; an out schema writes
        the result into an explicitly provided out argument.
        """
        is_out = bool(self.arguments.out)
        is_scratch = bool(
            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
        )
        is_inplace = self.name.name.inplace
        is_mutable = any(
            a.annotation is not None and a.annotation.is_write
            for a in self.arguments.post_self_positional
        )
        assert not (is_out and is_inplace)
        # out= and inplace take precedence over mutable post-self args
        # when classifying the schema.
        if is_inplace:
            return SchemaKind.inplace
        elif is_scratch:
            assert (
                is_out
            ), "invariant: all scratch operators are expected to be out= operators too"
            return SchemaKind.scratch
        elif is_out:
            assert (
                not is_scratch
            ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
            return SchemaKind.out
        elif is_mutable:
            return SchemaKind.mutable
        else:
            return SchemaKind.functional
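
    # Illustrative classification (sketch, not in the upstream file), using
    # real aten schemas:
    #   "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    #       -> SchemaKind.functional
    #   "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
    #       -> SchemaKind.inplace  (trailing underscore on the base name)
    #   "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
    #       -> SchemaKind.out  (mutable keyword-only out argument)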

    def aliased_return_names(self) -> list[str | None]:
        outs: list[str | None] = []
        for r in self.returns:
            aliased_args = [
                a
                for a in self.arguments.flat_all
                if a.annotation is not None and a.annotation == r.annotation
            ]
            if len(aliased_args) == 0:
                outs.append(None)
            elif len(aliased_args) == 1:
                outs.append(aliased_args[0].name)
            else:
                aliased_names = ", ".join(a.name for a in aliased_args)
                raise AssertionError(
                    f"Found a return ({r.name}) that aliases multiple inputs ({aliased_names})"
                )
        return outs

    def signature(
        self,
        *,
        strip_default: bool = False,
        strip_view_copy_name: bool = False,
        keep_return_names: bool = False,
    ) -> FunctionSchema:
        """
        Certain schemas are 'related', in that they are simply
        inplace/out/functional versions of the same function.  This method
        factors these schemas into the "core" functional signature which
        is equal across all versions.

        Here is the normalization applied to the schema to convert
        it to a signature:
        - The overload name is stripped (name is retained, since
          it expresses semantic content about what the function does)
        - Inplace is set False
        - Out arguments are stripped
        - Mutable post_self_positional args are converted to returns
        - Mutability annotations are stripped (this is sound
          because you cannot overload on mutability annotation)
        - Return names are stripped since they are not overloadable and
          some variants have return names but some not
        - TensorOptions are dropped
          because out= variants of factory functions don't include them
          (and we want to be able to pair up factory functions with their out variants)

        Finally, we want to be able to pair up related "view" and their
        corresponding "view_copy" operators. We do this by optionally
        stripping the trailing "_copy" from the base name.

        Example of a mutable op before and after:

        f.func (Mutable operator):
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950

        f.func (Corresponding functional operator):
        _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950

        f.func.signature() output:
        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
        """

        def strip_ret_annotation(r: Return) -> Return:
            return Return(
                name=r.name if keep_return_names else None,
                type=r.type,
                annotation=None,
            )

        base_name = self.name.name.base
        if strip_view_copy_name:
            if base_name.endswith("_copy"):
                base_name = base_name.replace("_copy", "")
            elif base_name.endswith("_scatter"):
                base_name = base_name.replace("scatter", "inverse")

        # Find mutable inputs that are not originally returned, and convert
        # them to returns.
        returns_from_mutable_inputs = tuple(
            # When grouping functions we strip the return names, but when
            # generating the actual functional variants we follow the
            # convention for out variants.
            Return(
                name=f"{a.name}_out" if keep_return_names else None,
                type=a.type,
                annotation=None,
            )
            for a in itertools.chain(
                # Order is important here (otherwise e.g. inplace with mutable
                # args and out= with mutable args won't have the same signature)
                [self.arguments.self_arg.argument]
                if self.arguments.self_arg is not None
                else [],
                self.arguments.out,
                self.arguments.post_self_positional,
            )
            if a.annotation is not None
            and a.annotation.is_write
            and not any(a.annotation == r.annotation for r in self.returns)
        )
        original_returns = tuple(map(strip_ret_annotation, self.returns))
        # Ordering is important: the "mutable input" returns come last.
        returns = original_returns + returns_from_mutable_inputs

        args_sig = self.arguments.signature(strip_default=strip_default)
        # See Note [bernoulli.p schema]
        if str(self.name) == "bernoulli.p":
            args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))

        return FunctionSchema(
            name=OperatorName(
                name=BaseOperatorName(
                    base=base_name,
                    inplace=False,
                    dunder_method=self.name.name.dunder_method,
                ),
                overload_name="",  # stripped
            ),
            arguments=args_sig,
            returns=returns,
        )

    def view_signature(self) -> FunctionSchema:
        return self.signature(strip_view_copy_name=True)

    def with_name(self, name: OperatorName) -> FunctionSchema:
        return FunctionSchema(
            name=name,
            arguments=self.arguments,
            returns=self.returns,
        )

    @property
    def modifies_arguments(self) -> bool:
        return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]

    def has_symint(self) -> bool:
        return self.arguments.has_symint_arg()

    def __str__(self) -> str:
        all_arguments_str = str(self.arguments)
        if len(self.returns) == 1:
            returns = str(self.returns[0])  # omit parentheses
        else:
            returns = "(" + ", ".join(map(str, self.returns)) + ")"
        return f"{self.name}({all_arguments_str}) -> {returns}"


@dataclass(frozen=True)
class Annotation:
    # Typically has one element; kept as a tuple (not a set) to preserve order.
    alias_set: tuple[str, ...]
    is_write: bool
    alias_set_after: tuple[str, ...]

    @staticmethod
    def parse(ann: str) -> Annotation:
        m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)

        assert m is not None, f"unrecognized alias annotation {ann}"
        before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
        alias_set = tuple(before_alias.split("|"))
        is_write = m.group(3) == "!"
        assert not (
            is_write and len(alias_set) > 1
        ), f"alias set larger than 1 is not mutable, got {ann} instead."
        after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
        assert not (
            len(before_alias) > 1 and len(after_set) > 1
        ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
        r = Annotation(
            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
        )
        assert str(r) == ann, f"{r} != {ann}"
        return r

    def __str__(self) -> str:
        alias_set = "|".join(self.alias_set)
        if self.is_write:
            alias_set = f"{alias_set}!"
        alias_set_after = "|".join(self.alias_set_after)
        if alias_set_after:
            alias_set = f'{alias_set}{" -> "}{alias_set_after}'
        return alias_set
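
    # Illustrative round-trips (sketch, not in the upstream file):
    #   Annotation.parse("a")       -> alias_set=("a",), is_write=False
    #   Annotation.parse("a!")      -> alias_set=("a",), is_write=True
    #   Annotation.parse("a! -> *") -> is_write=True, alias_set_after=("*",)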


# The base class of the type system; loosely modeled on jit_type.h, but
# simplified to the aspects that matter for code generation.
@dataclass(frozen=True)
class Type:
    @staticmethod
    def parse(t: str) -> Type:
        r = Type._parse(t)
        assert str(r) == t, f"{r} != {t}"
        return r

    @staticmethod
    def _parse(t: str) -> Type:
        m = re.match(r"^(.+)\?$", t)
        if m is not None:
            return OptionalType(Type.parse(m.group(1)))
        m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
        if m is not None:
            size = int(m.group(2)) if m.group(2) is not None else None
            return ListType(elem=Type.parse(m.group(1)), size=size)

        # '__torch__.torch.classes.' is the prefix for custom classes
        m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
        if m is not None:
            return CustomClassType(m.group(1))
        try:
            return BaseType(BaseTy[t])
        except KeyError as e:
            raise RuntimeError(f"unrecognized type {t}") from e

    def __str__(self) -> str:
        raise NotImplementedError

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        raise NotImplementedError

    def is_tensor_like(self) -> bool:
        return self.is_base_ty_like(BaseTy.Tensor)

    def is_generator_like(self) -> bool:
        return self.is_base_ty_like(BaseTy.Generator)

    def is_symint_like(self) -> bool:
        return self.is_base_ty_like(BaseTy.SymInt)

    def is_nullable(self) -> bool:
        raise NotImplementedError

    def is_list_like(self) -> ListType | None:
        raise NotImplementedError


# Base types are simple, atomic types with no further structure
class BaseTy(Enum):
    Generator = auto()
    ScalarType = auto()
    Tensor = auto()
    int = auto()
    Dimname = auto()
    DimVector = auto()
    float = auto()
    str = auto()
    bool = auto()
    Layout = auto()
    Device = auto()
    DeviceIndex = auto()
    Scalar = auto()
    MemoryFormat = auto()
    QScheme = auto()
    Storage = auto()
    Stream = auto()
    SymInt = auto()
    SymBool = auto()
    ConstQuantizerPtr = auto()
    GraphModule = auto()


@dataclass(frozen=True)
class BaseType(Type):
    name: BaseTy

    def __str__(self) -> str:
        return f"{self.name.name}"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        return self.name == base_ty

    def is_nullable(self) -> bool:
        return False

    def is_list_like(self) -> ListType | None:
        return None

    def is_symint_like(self) -> bool:
        return self.name == BaseTy.SymInt


@dataclass(frozen=True)
class OptionalType(Type):
    elem: Type

    def __str__(self) -> str:
        return f"{self.elem}?"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        return self.elem.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        return self.elem.is_symint_like()

    def is_nullable(self) -> bool:
        return True

    def is_list_like(self) -> ListType | None:
        return self.elem.is_list_like()


@dataclass(frozen=True)
class CustomClassType(Type):
    class_name: str

    def __str__(self) -> str:
        """
        Return the class name, prefixed with __torch__.torch.classes
        """
        return f"__torch__.torch.classes.{self.class_name}"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        return False

    def is_symint_like(self) -> bool:
        return False

    def is_nullable(self) -> bool:
        """
        Assume a custom class is not nullable.
        """
        return False

    def is_list_like(self) -> ListType | None:
        return None


@dataclass(frozen=True)
class ListType(Type):
    elem: Type
    size: int | None

    def __str__(self) -> str:
        size = f"{self.size}" if self.size else ""
        return f"{self.elem}[{size}]"

    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
        return self.elem.is_base_ty_like(base_ty)

    def is_symint_like(self) -> bool:
        return self.elem.is_symint_like()

    def is_nullable(self) -> bool:
        return self.elem.is_nullable()

    def is_list_like(self) -> ListType | None:
        return self
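
# Illustrative round-trips through Type.parse (sketch, not in the upstream
# file):
#   Type.parse("Tensor")  -> BaseType(BaseTy.Tensor)
#   Type.parse("Tensor?") -> OptionalType(BaseType(BaseTy.Tensor))
#   Type.parse("int[2]")  -> ListType(elem=BaseType(BaseTy.int), size=2)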


@dataclass(frozen=True)
class Argument:
    name: str
    type: Type
    default: str | None

    annotation: Annotation | None

    @property
    def alias_info(self) -> Annotation | None:
        return self.annotation

    @staticmethod
    def parse(arg: str) -> Argument:
        name: str
        default: str | None
        assert " " in arg, f"illegal argument '{arg}'"
        if "=" in arg:
            assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'"
            type_and_annot_and_name, default = arg.split("=")
            type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1)
            name_and_default = f"{name}={default}"
        else:
            type_and_annot, name_and_default = arg.rsplit(" ", 1)
            name = name_and_default
            default = None
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        annotation: Annotation | None
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in [
                "",
                "?",
                "[]",
            ], "unrecognized alias analysis form with Tensor"
            type_s = "Tensor" + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Argument(
            name=name,
            type=type,
            default=default,
            annotation=annotation,
        )
        assert str(r) == arg, f"{str(r)} != {arg}"
        return r

    @property
    def is_write(self) -> bool:
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        type = f"{self.type}"
        if self.annotation:
            assert type in ["Tensor", "Tensor?", "Tensor[]"]
            type = type.replace("Tensor", f"Tensor({self.annotation})")
        if self.name is None:
            return type
        else:
            mb_default = ""
            if self.default:
                mb_default = f"={self.default}"
            return f"{type} {self.name}{mb_default}"


@dataclass(frozen=True)
class Return:
    name: str | None
    type: Type
    annotation: Annotation | None

    @property
    def alias_info(self) -> Annotation | None:
        return self.annotation

    @staticmethod
    def parse(arg: str) -> Return:
        name: str | None
        if " " in arg:
            type_and_annot, name = arg.rsplit(" ", 1)
        else:
            type_and_annot = arg
            name = None
        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
        annotation: Annotation | None
        if match:
            # If you update this, make sure the __str__ still works too
            assert match.group(2) in [
                "",
                "?",
                "[]",
            ], "unrecognized alias analysis form with Tensor"
            type_s = "Tensor" + match.group(2)
            annotation = Annotation.parse(match.group(1))
        else:
            type_s = type_and_annot
            annotation = None
        type = Type.parse(type_s)
        r = Return(
            name=name,
            type=type,
            annotation=annotation,
        )
        assert str(r) == arg, f"{str(r)} != {arg}"
        return r

    @property
    def is_write(self) -> bool:
        return self.annotation is not None and self.annotation.is_write

    def __str__(self) -> str:
        type = f"{self.type}"
        if self.annotation:
            assert type in ["Tensor", "Tensor?", "Tensor[]"]
            type = type.replace("Tensor", f"Tensor({self.annotation})")
        if self.name is None:
            return type
        else:
            return f"{type} {self.name}"


# Represents the self argument for functions that may be methods
@dataclass(frozen=True)
class SelfArgument:
    argument: Argument


# Bundle of arguments that represent a TensorOptions.
@dataclass(frozen=True)
class TensorOptionsArguments:
    dtype: Argument
    layout: Argument
    device: Argument
    pin_memory: Argument

    def all(self) -> Sequence[Argument]:
        return [self.dtype, self.layout, self.device, self.pin_memory]


@dataclass(frozen=True)
class Arguments:
    # pre_self_positional is usually empty, but a few operators place
    # arguments before self.
    pre_self_positional: tuple[Argument, ...]
    self_arg: SelfArgument | None
    post_self_positional: tuple[Argument, ...]

    pre_tensor_options_kwarg_only: tuple[Argument, ...]
    tensor_options: TensorOptionsArguments | None
    post_tensor_options_kwarg_only: tuple[Argument, ...]

    # out arguments are factored out into their own field in the canonical
    # representation; they are also kwarg-only.
    out: tuple[Argument, ...]

    @property
    def flat_non_out(self) -> Sequence[Argument]:
        ret: list[Argument] = []
        ret.extend(self.flat_positional)
        ret.extend(self.flat_kwarg_only)
        return ret

    @property
    def flat_positional(self) -> Sequence[Argument]:
        ret: list[Argument] = []
        ret.extend(self.pre_self_positional)
        if self.self_arg is not None:
            ret.append(self.self_arg.argument)
        ret.extend(self.post_self_positional)
        return ret

    @property
    def post_self_positional_mutable(self) -> Sequence[Argument]:
        return [a for a in self.post_self_positional if a.is_write]

    # NB: doesn't contain out arguments
    @property
    def flat_kwarg_only(self) -> Sequence[Argument]:
        ret: list[Argument] = []
        ret.extend(self.pre_tensor_options_kwarg_only)
        if self.tensor_options is not None:
            ret.extend(self.tensor_options.all())
        ret.extend(self.post_tensor_options_kwarg_only)
        return ret

    @property
    def flat_all(self) -> Sequence[Argument]:
        ret: list[Argument] = []
        ret.extend(self.flat_positional)
        ret.extend(self.flat_kwarg_only)
        ret.extend(self.out)
        return ret

    @property
    def non_out(
        self,
    ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
        ret.extend(self.positional)
        ret.extend(self.kwarg_only)
        return ret

    @property
    def positional(self) -> Sequence[Argument | SelfArgument]:
        ret: list[Argument | SelfArgument] = []
        ret.extend(self.pre_self_positional)
        if self.self_arg is not None:
            ret.append(self.self_arg)
        ret.extend(self.post_self_positional)
        return ret

    @property
    def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
        ret: list[Argument | TensorOptionsArguments] = []
        ret.extend(self.pre_tensor_options_kwarg_only)
        if self.tensor_options is not None:
            ret.append(self.tensor_options)
        ret.extend(self.post_tensor_options_kwarg_only)
        return ret

    @property
    def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
        ret.extend(self.positional)
        ret.extend(self.kwarg_only)
        ret.extend(self.out)
        return ret

    def mutable_arg_names(self) -> list[str]:
        return [
            a.name
            for a in self.flat_all
            if a.annotation is not None and a.annotation.is_write
        ]

    def has_tensor_arg(self) -> bool:
        return any(a.type.is_tensor_like() for a in self.flat_non_out)

    def has_symint_arg(self) -> bool:
        return any(a.type.is_symint_like() for a in self.flat_non_out)

    def has_generator_arg(self) -> bool:
        return any(a.type.is_generator_like() for a in self.flat_non_out)

    def signature(self, *, strip_default: bool = False) -> Arguments:
        def strip_arg_annotation(a: Argument) -> Argument:
            return Argument(
                name=a.name,
                type=a.type,
                default=a.default if not strip_default else None,
                annotation=None,
            )

        return Arguments(
            pre_self_positional=tuple(
                map(strip_arg_annotation, self.pre_self_positional)
            ),
            self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument))
            if self.self_arg is not None
            else None,
            post_self_positional=tuple(
                map(strip_arg_annotation, self.post_self_positional)
            ),
            # Since TensorOptions are dropped, the post_tensor_options_kwargs
            # are merged into the ordinary kwarg-only arguments.
            pre_tensor_options_kwarg_only=tuple(
                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
            )
            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
            tensor_options=None,
            post_tensor_options_kwarg_only=(),
            # out arguments are dropped in the signature
            out=(),
        )

    def remove_self_annotation(self) -> Arguments:
        assert self.self_arg is not None
        return dataclasses.replace(
            self,
            self_arg=SelfArgument(
                dataclasses.replace(self.self_arg.argument, annotation=None)
            ),
        )

    def with_out_args(self, outs: list[Argument]) -> Arguments:
        assert len(self.out) == 0
        return dataclasses.replace(
            self,
            out=tuple(outs),
        )

    @staticmethod
    def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
        positional: list[Argument] = []
        kwarg_only: list[Argument] = []
        out: list[Argument] = []
        arguments_acc = positional

        for arg in args.split(", "):
            if not arg:
                continue
            if arg == "*":
                assert (
                    arguments_acc is positional
                ), "invalid syntax: kwarg-only specifier * can only occur once"
                arguments_acc = kwarg_only
                continue
            parg = Argument.parse(arg)
            # We rely on the invariant that there are no kwarg-only mutating
            # arguments other than out arguments.
            if parg.annotation is not None and parg.annotation.is_write:
                if arguments_acc is positional:
                    pass  # do nothing
                elif arguments_acc is kwarg_only:
                    arguments_acc = out
            else:
                assert arguments_acc is not out
            arguments_acc.append(parg)

        return positional, kwarg_only, out

    @staticmethod
    def parse(args: str) -> Arguments:
        """
        Input: 'int x, int y, int z'
        """

        # Two phases: first parse into positional, kwarg_only and out;
        # then separate out the self argument and TensorOptions arguments.
        positional, kwarg_only, out = Arguments._preparse(args)

        # Split self argument
        self_ix = None
        for i, a in enumerate(positional):
            if a.name == "self":
                self_ix = i
                break
        pre_self_positional: list[Argument]
        self_arg: SelfArgument | None
        post_self_positional: list[Argument]
        if self_ix is not None:
            pre_self_positional = positional[:self_ix]
            self_arg = SelfArgument(positional[self_ix])
            post_self_positional = positional[self_ix + 1 :]
        else:
            pre_self_positional = []
            self_arg = None
            post_self_positional = positional

        # Group TensorOptions arguments
        pre_tensor_options_kwarg_only: list[Argument] = []
        tensor_options: TensorOptionsArguments | None = None
        post_tensor_options_kwarg_only: list[Argument] = []
        kwarg_only_acc = pre_tensor_options_kwarg_only

        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
            return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]

        predicates = [  # order matters
            pred("dtype", Type.parse("ScalarType")),
            pred("layout", Type.parse("Layout")),
            pred("device", Type.parse("Device")),
            pred("pin_memory", Type.parse("bool")),
        ]

        i = 0
        while i < len(kwarg_only):
            # If there are enough remaining arguments...
            if i <= len(kwarg_only) - len(predicates):
                # ...and the next len(predicates) look like TensorOptions arguments
                if all(
                    p(a)
                    for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
                ):
                    assert kwarg_only_acc is pre_tensor_options_kwarg_only
                    # Group them together as one argument
                    tensor_options = TensorOptionsArguments(
                        dtype=kwarg_only[i],
                        layout=kwarg_only[i + 1],
                        device=kwarg_only[i + 2],
                        pin_memory=kwarg_only[i + 3],
                    )
                    i += len(predicates)
                    kwarg_only_acc = post_tensor_options_kwarg_only
                    continue
            kwarg_only_acc.append(kwarg_only[i])
            i += 1

        return Arguments(
            pre_self_positional=tuple(pre_self_positional),
            self_arg=self_arg,
            post_self_positional=tuple(post_self_positional),
            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
            tensor_options=tensor_options,
            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
            out=tuple(out),
        )
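
    # Illustrative usage (sketch, not in the upstream file):
    #
    #   a = Arguments.parse("Tensor self, Tensor other, *, Scalar alpha=1")
    #   a.self_arg is not None
    #   [x.name for x in a.flat_kwarg_only] == ["alpha"]
    #
    # A kwarg-only run of (dtype, layout, device, pin_memory) is grouped into
    # a single TensorOptionsArguments bundle by the predicate loop above.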

    def __str__(self) -> str:
        all_arguments: list[str] = []
        all_arguments.extend(map(str, self.flat_positional))
        if self.flat_kwarg_only or self.out:
            all_arguments.append("*")
        all_arguments.extend(map(str, self.flat_kwarg_only))
        all_arguments.extend(map(str, self.out))
        return ", ".join(all_arguments)

    def __post_init__(self) -> None:
        if self.self_arg is None:
            assert not self.pre_self_positional
        if self.tensor_options is None:
            assert not self.post_tensor_options_kwarg_only

        # Mutability annotations on pre-self positional arguments are not
        # allowed, to keep things simple.
        mutable_pre_self_positionals = [
            a
            for a in self.pre_self_positional
            if a.annotation is not None and a.annotation.is_write
        ]
        assert (
            len(mutable_pre_self_positionals) == 0
        ), "mutable pre_self_positional arguments are not currently supported in the schema"


# Names that validly are __iXXX__ indicating inplace operations.
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
# NB: PyTorch hasn't actually implemented all of these.
AUGMENTED_ASSIGNMENT_NAMES = [
    "add",
    "sub",
    "mul",
    "div",
    "mod",
    "pow",
    "lshift",
    "rshift",
    "and",
    "xor",
    "or",
]


# A BaseOperatorName is the operator name without the overload name.
@dataclass(frozen=True)
class BaseOperatorName:
    base: str
    inplace: bool
    dunder_method: bool
    # See Note [Overload Ambiguity With Functional Variants]
    functional_overload: bool = False

    @staticmethod
    def parse(op: str) -> BaseOperatorName:
        assert op != ""
        assert not op.endswith("_out"), (
            "_out suffix is reserved and not permitted for operator names; "
            "did you mean to specify an out overload name instead?"
        )
        m = re.match(r"^__([^_]+)__$", op)
        if m is not None:
            dunder_method = True
            base = m.group(1)
            if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
                inplace = True
                base = base[1:]
            else:
                inplace = False
                # Not intrinsically true, but historically true for the
                # dunder methods we support.
                assert base[0] != "i"
        else:
            dunder_method = False
            base = op
            if base[-1] == "_":
                inplace = True
                base = base[:-1]
            else:
                inplace = False

        # See Note [Overload Ambiguity With Functional Variants]
        functional_suffix = "_functional"
        if base.endswith(functional_suffix):
            functional_overload = True
            base = base[: -len(functional_suffix)]
            # Dunder methods and inplace ops are banned on functional variants
            # for now.
            assert not dunder_method and not inplace
        else:
            functional_overload = False

        r = BaseOperatorName(
            base=base,
            inplace=inplace,
            dunder_method=dunder_method,
            functional_overload=functional_overload,
        )
        assert str(r) == op, f"{str(r)} != {op}"
        return r

    def __str__(self) -> str:
        if self.dunder_method:
            i = "i" if self.inplace else ""
            return f"__{i}{self.base}__"
        else:
            i = (
                "_"
                if self.inplace
                else "_functional"
                if self.functional_overload
                else ""
            )
            return f"{self.base}{i}"


# The operator name is the base operator name together with the (typically
# not user visible) overload string.
@dataclass(frozen=True)
class OperatorName:
    name: BaseOperatorName
    overload_name: str

    @staticmethod
    def parse(op_name: str) -> OperatorName:
        if "." in op_name:
            name, overload_name = op_name.split(".", 1)
        else:
            name = op_name
            overload_name = ""
        r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
        assert str(r) == op_name, f"{str(r)} != {op_name}"
        return r

    def __str__(self) -> str:
        if self.overload_name:
            return f"{self.name}.{self.overload_name}"
        else:
            return f"{self.name}"

    # NB: This must be synchronized with the naming scheme in
    # aten/src/ATen/templates/Operators.h
    def unambiguous_name(self) -> str:
        if self.overload_name:
            return f"{self.name}_{self.overload_name}"
        else:
            return f"{self.name}"

    def remove_inplace(self) -> OperatorName:
        return OperatorName(
            name=BaseOperatorName(
                base=self.name.base,
                inplace=False,
                dunder_method=self.name.dunder_method,
            ),
            overload_name=self.overload_name,
        )

    def with_overload(self, overload: str) -> OperatorName:
        return OperatorName(
            name=BaseOperatorName(
                base=self.name.base,
                inplace=False,
                dunder_method=self.name.dunder_method,
            ),
            overload_name=overload,
        )


def gets_generated_out_inplace_wrapper(
    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
) -> bool:
    return (
        f.func.kind() is not SchemaKind.functional
        and not b.has_kernel(f)
        and b.has_kernel(g.functional)
    )


# View NativeFunctions (f.is_view_op is True) are grouped into a
# NativeFunctionsViewGroup, giving easy access to the (optional) generated
# view_copy NativeFunction.
@dataclass(frozen=True)
class NativeFunctionsViewGroup:
    view: NativeFunction
    # The {view}_copy operator is optional because we don't generate copy
    # variants for all view ops (notably CompositeImplicitAutograd views,
    # which we let decompose into their base view ops).
    view_copy: NativeFunction | None
    # view_inplace ops are also optional, but every view_inplace op should
    # have an out-of-place variant.
    view_inplace: NativeFunction | None

    def __post_init__(self) -> None:
        assert self.view.is_view_op
        if self.view_copy is None:
            assert not gets_generated_view_copy(self.view), (
                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
                f" {get_view_copy_name(self.view)!s}."
                " See Note [view_copy NativeFunctions] for details."
            )
        else:
            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
            assert self.view.func.signature() == self.view_copy.func.signature(
                strip_view_copy_name=True,
            )
            assert "view_copy" in self.view_copy.tags, (
                f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
                " See Note [view_copy NativeFunction] for details."
            )
        if self.view_inplace is not None:
            assert self.view.func.signature() == self.view_inplace.func.signature()

        if self.view.has_composite_implicit_autograd_kernel:
            if self.view_inplace is not None:
                assert self.view_inplace.has_composite_implicit_autograd_kernel, (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
                )
        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
            if self.view_inplace is not None:
                assert (
                    self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel
                ), (
                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
                )

    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
        yield self.view
        if self.view_inplace is not None:
            yield self.view_inplace
        if self.view_copy is not None and include_copy:
            yield self.view_copy

    @property
    def root_name(self) -> str:
        return self.view.root_name

    @property
    def composite(self) -> bool:
        # The group is consistent: if the view op is composite,
        # then its view_inplace op is too.
        return self.view.has_composite_implicit_autograd_kernel


def gets_generated_view_copy(f: NativeFunction) -> bool:
    # Only aliasing (view) operators get a copy variant.
    if not f.is_view_op:
        return False
    # CompositeImplicitAutograd views don't need copy variants; we let them
    # decompose into base view ops.
    if f.has_composite_implicit_autograd_kernel:
        return False
    # Inplace views don't get copy variants either.
    if "inplace_view" in f.tags:
        return False
    # Ops ending in _inverse are assumed to have manually-defined copy
    # variants (e.g. slice_inverse()).
    if f.func.name.name.base.endswith("_inverse"):
        return False
    return True


# Given a NativeFunction that corresponds to a view op, returns the
# OperatorName of the corresponding "copy" variant of the op.
def get_view_copy_name(f: NativeFunction) -> OperatorName:
    # Sanity-check that the op is allowed a generated view_copy variant.
    # narrow_copy() already exists directly in native_functions.yaml, so it
    # is hardcoded as an exception to keep the assert.
    list_of_ops_with_explicit_view_copy_operators = ["narrow"]
    if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
        assert gets_generated_view_copy(f)

    base_name = f"{f.func.name.name.base}_copy"
    view_copy_name = OperatorName(
        name=BaseOperatorName(
            base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
        ),
        overload_name=f.func.name.overload_name,
    )
    return view_copy_name


def parse_returns(return_decl: str) -> tuple[Return, ...]:
    """
    Input: '()'
    Output: ()
    """
    if return_decl == "()":
        return ()
    if return_decl[0] == "(" and return_decl[-1] == ")":
        return_decl = return_decl[1:-1]
    return tuple(Return.parse(arg) for arg in return_decl.split(", "))
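
# Illustrative round-trips (sketch, not in the upstream file):
#   parse_returns("()")               -> ()
#   parse_returns("Tensor")           -> one unnamed Tensor Return
#   parse_returns("(Tensor, Tensor)") -> two unnamed Tensor Returns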


@dataclass(frozen=True)
class Precompute:
    # A map from kernel argument name -> a list of precomputed elements
    # that replaces/supersedes it.
    replace: dict[str, list[Argument]]
    # Precomputed args added without replacement.
    add: list[Argument]

    @staticmethod
    def parse(src: object) -> Precompute:
        assert isinstance(src, list)

        # src is a list of strings of the format:
        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
        # The last line may instead be a comma-separated list of precomputed
        # parameters that are added without replacement.
        add_args = []
        if " -> " not in src[-1]:
            add_list = src[-1].split(",")
            add_args = [Argument.parse(name.strip()) for name in add_list]
            src = src[:-1]

        replace = {}
        for raw_replace_item in src:
            assert isinstance(raw_replace_item, str)
            assert " -> " in raw_replace_item, (
                "precomputed parameters without replacement"
                " are allowed only in the last line"
            )

            arg, with_list_raw = raw_replace_item.split(" -> ")
            assert (
                " " not in arg
            ), f"illegal kernel param name '{arg}' in precomputed parameters"
            with_list = with_list_raw.split(",")
            with_list_args = [Argument.parse(name.strip()) for name in with_list]
            replace[arg] = with_list_args

        r = Precompute(replace=replace, add=add_args)
        assert r.to_list() == src, "r.to_list() != src"
        return r

    def __post_init__(self) -> None:
        # The template parameters are upper-case, so a name equal to its
        # upper-cased form would be ambiguous.
        for a in self.add:
            assert a.name.upper() != a.name
        for args in self.replace.values():
            for a in args:
                assert a.name.upper() != a.name

    def to_list(self) -> list[str]:
        replace_list = []
        for kernel_param, replacement_params in self.replace.items():
            replacements = ", ".join(str(param) for param in replacement_params)
            replace_list.append(f"{kernel_param} -> {replacements}")

        return replace_list
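
# Illustrative usage (sketch; "kernel_size", "kH" and "kW" are hypothetical
# names, not taken from the upstream file):
#
#   p = Precompute.parse(["kernel_size -> int kH, int kW"])
#   list(p.replace) == ["kernel_size"] and p.add == []
#   p.to_list() == ["kernel_size -> int kH, int kW"]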