from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
from torchgen.gen import pythonify_default
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    Type,
    Variant,
)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           Data Models
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# [Notes] python binding codegen
#
# The Python binding codegen produces code that takes the input list of
# PyObjects, finds the matching ATen C++ function using PythonArgParser,
# converts the PyObjects into C++ types and calls the ATen C++ function:
#
# +--------+  parsing   +------------------------+  binding   +-----------------------+
# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
# +--------+            +------------------------+            +-----------------------+
#
# The following examples demonstrate the data models the Python binding
# codegen needs to deal with and the tasks it needs to accomplish. They
# help explain the purpose of the new data types introduced below.
#
# - Function Schema (source of truth)
#
#      aten::empty.names(int[] size, *, Dimname[]? names,
#                        ScalarType? dtype=None, Layout? layout=None,
#                        Device? device=None, bool? pin_memory=None,
#                        MemoryFormat? memory_format=None) -> Tensor
#
# - Python Signature
#
#   It's used to generate the input schema string for PythonArgParser.
#   Note: TensorOptions fields are reordered and the additional
#   'requires_grad' field is added:
#
#      empty(IntArrayRef size, *, DimnameList? names,
#            MemoryFormat? memory_format=None, ScalarType dtype=None,
#            Layout layout=torch.strided, Device device=None,
#            bool pin_memory=False, bool requires_grad=False)
#
# - C++ Signature
#
#   It's used to generate C++ lambda formals & dispatch call.
#   Note: the scattered TensorOptions fields are packed into 'options'.
#
#      auto dispatch_empty =
#          [](IntArrayRef size, std::optional<DimnameList> names,
#             const TensorOptions & options,
#             std::optional<MemoryFormat> memory_format) -> Tensor {
#            pybind11::gil_scoped_release no_gil;
#            return torch::empty(size, names, options, memory_format);
#          };
#
# - Binding between Python Arguments and C++ Arguments
#
#   Given a set of Python Arguments in scope, we need to produce the
#   binding expressions that translate the Python API into C++ API:
#
#            Python Args               Cpp Args              Binding Exprs
#     -----------------------------------------------------------------
#         0: size                      size                  '_r.intlist(0)'
#         1: names                     names                 'names' [special init]
#         2: memory_format -------+
#         3: dtype          -----+-|--> options              'options' [special packing]
#         4: layout            /  |
#         5: device           /   +--> memory_format         '_r.memoryformatOptional(2)'
#         6: pin_memory      /
#         7: requires_grad -+
#
#   So the full dispatch expression would look like:
#
#     dispatch_empty(_r.intlist(0), names, options,
#                    _r.memoryformatOptional(2))
#
#   Where does 'names' come from? It involves a special local init:
#
#     auto __names = _r.toDimnameListOptional(1);
#     std::optional<DimnameList> names =
#         __names ? std::make_optional(DimnameList(__names.value()))
#                 : std::nullopt;
#
#   Where does 'options' come from? It involves a special local init
#   for TensorOptions. Note that the Python side has the additional
#   'requires_grad' field:
#
#     const auto options = TensorOptions()
#         .dtype(_r.scalartype(3))
#         .device(_r.device(5))
#         .layout(_r.layoutOptional(4))
#         .requires_grad(_r.toBool(7))
#         .pinned_memory(_r.toBool(6));
#
#   In some other cases one Python Argument can map to multiple C++
#   Arguments. For example:
#
#     aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
#       -> (Tensor values, Tensor indices)
#
#            Python Args               Cpp Args              Binding Exprs
#     ---------------------------------------------------------------------
#                               /----> max                   'out[0]'
#                              /-----> max_values            'out[1]'
#                             /
#         0: input           /        self                   '_r.tensor(0)'
#         1: dim            /         dim                    '_r.dimname(1)'
#         2: keepdim       /          keepdim                '_r.toBool(2)'
#         3: out -----+---+           [local init] out       '_r.tensorlist_n<2>(3)'
#
#   As demonstrated above, the binding can involve reordering,
#   packing, unpacking and special local inits.
#
#   Let's look at a concrete example:
#
#      static PythonArgParser parser({
#        "abs(Tensor input, *, Tensor out=None)",
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         +--- Python Schema, represented by PythonSignature and PythonArgument
#      }, /*traceable=*/true);
#
#      ParsedArgs<2> parsed_args;
#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
#
#      ...
#
#      if (_r.isNone(1)) {
#          ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
#                            represented by PythonArgParserOutputExpr
#
#        // aten::abs(Tensor self) -> Tensor
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         +--- NativeFunction schema, base version
#
#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         +--- dispatch_lambda_args / dispatch_lambda_return_str
#              generated from NativeFunction / CppSignature
#              (deprecated PythonSignature is special)
#              arguments are represented by DispatchLambdaArgument
#
#          pybind11::gil_scoped_release no_gil;
#          return at::abs(self);
#                 ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
#                                  generated from NativeFunction / CppSignature
#        };
#
#        return wrap(dispatch_abs(_r.tensor(0)));
#                    ~~~~~~~~~~~~~~~~~~~~~~~~~~
#         +--- dispatch_lambda_exprs
#              binding PythonArgParserOutputExpr (python args)
#              and DispatchLambdaArgument (c++ args)
#
#      } else {
#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#         +--- NativeFunction schema, out-variant
#
#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
#          pybind11::gil_scoped_release no_gil;
#          return at::abs_out(out, self);
#        };
#
#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
#      }
#
# [Notes] python interface codegen
# The python dataclasses below are used to generate both python binding code
# and pyi type hint signatures.
# In theory these two should look very similar, but there are a number of differences
# in how pyi signatures vs. python_arg_parser signatures are generated.
# These differences have been encapsulated in signature_str() vs. signature_str_pyi()
# to display the full signatures, and argument_str() vs. argument_str_pyi() to display arguments.
# For example, only pyi signatures include return types.
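#
# As a hedged illustration (hand-written here, not generated output), the 'abs'
# overload from the notes above would render roughly as:
#   signature_str():     "abs(Tensor input, *, Tensor out=None)"
#   signature_str_pyi(): "def abs(input: Tensor, *, out: Optional[Tensor] = None) -> Tensor: ..."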


@dataclass(frozen=True)
class PythonReturns:
    returns: tuple[Return, ...]


@dataclass(frozen=True)
class PythonArgument:
    name: str
    type: Type
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            +--- default_init str
    default_init: str | None

    # Compute the argument formal for python argument parsing.
    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        type_str = (
            argument_type_str(self.type, symint=symint)
            .replace("const ", "")
            .replace(" &", "")
        )

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str in ["Tensor", "Number"] and not method:
            name = "input"

        # add default
        if self.default is not None:
            default = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
            }.get(self.default, self.default)
            return f"{type_str} {name}={default}"
        return f"{type_str} {name}"
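
    # e.g. (illustrative, not generated output): a positional 'Tensor self' in a
    # function binding renders as 'Tensor input', and 'Scalar alpha=1' renders
    # unchanged as 'Scalar alpha=1'.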

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        type_str = argument_type_str_pyi(self.type)

        name = self.name
        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        if name == "from":  # from is a Python keyword...
            name += "_"

        # pyi merges the _out and functional variants into the same signature, with an optional out arg
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = "Optional[" + type_str + "]"

        # pyi deprecated signatures don't get defaults for their out arg
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )

        # add default
        if self.default is not None and not treat_as_no_default:
            if (
                isinstance(self.type, ListType)
                and self.type.elem == BaseType(BaseTy.int)
                and self.default.startswith("{")
                and self.default.endswith("}")
            ):
                default = (
                    "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
                )
            else:
                default = {
                    "nullptr": "None",
                    "::std::nullopt": "None",
                    "std::nullopt": "None",
                    "{}": "None",
                    "c10::MemoryFormat::Contiguous": "contiguous_format",
                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
                }.get(self.default, self.default)
            return f"{name}: {type_str} = {default}"
        return f"{name}: {type_str}"


@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    # In the Python signature multiple output fields are packed into one 'out'
    # argument. When binding to C++, it's first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # then bound to the scattered C++ output arguments as 'out[0]', 'out[1]', etc.
    # TODO: maybe we don't need to keep the scattered out fields for the python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        if not outputs:
            return None

        size = len(outputs)
        if size == 1:
            return PythonOutArgument(
                name=outputs[0].name,
                type=outputs[0].type,
                default="None",
                default_init=None,
                outputs=outputs,
            )
        elif size > 1:
            if any(not a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            return PythonOutArgument(
                name="out",
                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                type=ListType(BaseType(BaseTy.Tensor), size),
                default="None",
                default_init=None,
                outputs=outputs,
            )
        raise AssertionError(r"Unexpected PythonOutArgument size")


@dataclass(frozen=True)
class PythonSignature:
    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are scattered kwargs arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take a TensorOptions object (e.g.
    # for the out variant), in which case they will be used as scattered fields without
    # being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        result: list[PythonArgument | PythonOutArgument] = []
        result.extend(self.input_args)
        result.extend(self.input_kwargs)
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        return len(self.arguments())

    def output_idx(self) -> int:
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f'{self.name}({", ".join(schema_formals)})'
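
    # e.g. (illustrative) for aten::add.Tensor this produces:
    #   "add(Tensor input, Tensor other, *, Scalar alpha=1)"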

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # only pyi uses vararg signatures
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]

        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
        num_args = self.arguments_count()
        num_positionalargs = len(self.input_args)

        have_vararg_version = False
        if num_args > 0:
            vararg_type = args[0].type
            if (
                isinstance(vararg_type, ListType)
                and str(vararg_type.elem) in ["int", "SymInt"]
                and num_positionalargs == 1
            ):
                have_vararg_version = True

        if not have_vararg_version:
            return None

        # Below are the major changes in vararg vs. regular pyi signatures.
        # vararg signatures also omit the asterisk.
        assert isinstance(vararg_type, ListType)
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
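
    # e.g. (illustrative) for a method like 'permute(int[] dims)' the vararg
    # variant looks roughly like:
    #   "def permute(self, *dims: _int) -> Tensor: ..."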


# The deprecated python signature involves some special logic, so create a
# dedicated data model to store these extra properties.
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    # Schema for the deprecated function
    deprecated_schema: FunctionSchema

    # The deprecated signature might miss some arguments that the corresponding
    # C++ signature expects. We need to store the constant default values to pass in.
    # e.g.:
    #   [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        return (
            PythonSignature.signature_str(
                self, skip_outputs=skip_outputs, symint=symint
            )
            + "|deprecated"
        )

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        returns_str = returns_str_pyi(self)
        return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        # the codegen doesn't include vararg variants for deprecated signatures
        return None


# This struct is used to hold the PythonSignature and its corresponding
# NativeFunction BEFORE grouping base and out-variant functions.
# Why not store NativeFunction in PythonSignature or construct PythonSignature
# from NativeFunction? Because they are not 1-1 mapped.
# One native function could have both deprecated and non-deprecated python
# signatures - NativeFunction doesn't contain the information needed to
# construct the deprecated python signature.
# One python signature is used to handle both the base and the out-variant
# functions - see 'PythonSignatureGroup'.
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    signature: PythonSignature
    function: NativeFunction


# We merge pairs of functions with signatures that are equivalent mod
# output arguments, and use a single entry in the python_arg_parser sig
# list for both (output arguments become optional).
@dataclass(frozen=True)
class PythonSignatureGroup:
    # The signature used for Python argument parsing. The outplace signature
    # is preferred if it exists, because it can be used to parse inputs for both
    # the out-place variant and the base version (with output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out)
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        if out is None:
            return PythonSignatureGroup(
                signature=functional.signature,
                base=functional.function,
                outplace=None,
            )

        # prefer the signature with optional out=... arguments because it's the
        # superset that can be used to parse input for both base and outplace.
        signature_kwargs = out.signature.__dict__.copy()

        # Out overloads in C++ don't have TensorOptions arguments,
        # so take these from the functional variant
        signature_kwargs[
            "tensor_options_args"
        ] = functional.signature.tensor_options_args

        return PythonSignatureGroup(
            signature=type(out.signature)(**signature_kwargs),
            base=functional.function,
            outplace=out.function,
        )


# C++ function dispatch is wrapped in a lambda function. The lambda function
# has almost the same signature as the C++ function, only with some small
# variants - see details below.
# This data model is used to represent arguments of the lambda function.
@dataclass(frozen=True)
class DispatchLambdaArgument:
    name: str
    type_str: str
    is_out_arg: bool


# To pass PyObject arguments to the C++ function (via the lambda wrapper),
# we first need to convert the PyObjects into simple C++ objects. This work
# is done by PythonArgParser.
# This data model is used to represent the output of PythonArgParser.
# It has a 1-1 mapping with PythonArgument in PythonSignature.
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    # argument name
    name: str

    # RHS expression to reference PythonArgParser output.
    expr: str

    # In some special cases we need to create a different expr, e.g.:
    # '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument it maps to.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        return f"_r.isNone({self.index})"
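
    # e.g. (illustrative): a 'bool keepdim' argument at index 2 would carry
    # expr '_r.toBool(2)' and is_none_expr '_r.isNone(2)'.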


# To pass PythonArgParser output to the lambda wrapper, we need to bind
# PythonArgParserOutputExpr to DispatchLambdaArgument.
# They are not always 1-1 mapped, e.g. scattered TensorOptions fields
# need to be packed into a TensorOptions object, which is the argument
# that the lambda function wrapper takes.
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    # The exprs that provide the binding for lambda arguments, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # It has a 1-1 mapping with DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local inits, which might introduce new variables that
    # the 'exprs' above reference, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Helper Functions
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    return CppSignatureGroup.from_native_function(f, method=method).signature


def has_tensor_options(f: NativeFunction) -> bool:
    return f.func.arguments.tensor_options is not None


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Signature
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# 'simple_type' was introduced by the old codegen; it is slightly different
# from the python schema type, e.g.: it doesn't have the '?' suffix for
# optional Tensor/TensorList, and doesn't have the '[size]' suffix for list types.
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return "Tensor"
        elif t.name == BaseTy.int:
            return "int64_t"
        elif t.name == BaseTy.float:
            return "double"
        elif t.name == BaseTy.str:
            return "c10::string_view"
        elif t.name in [
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.ConstQuantizerPtr,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            # Is it desired to keep '?' for simple_type with new style dispatcher?
            return "Tensor?"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"{elem}?"
    elif isinstance(t, ListType):
        size = t.size if not simple_type else None
        if str(t.elem) == "bool":
            assert t.size is not None
            return f"::std::array<bool,{t.size}>"
        elif str(t.elem) == "int":
            return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "SymInt":
            if symint:
                return (
                    f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
                )
            else:
                return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
        elif str(t.elem) == "Tensor":
            return f"TensorList[{size}]" if size is not None else "TensorList"
        elif str(t.elem) == "Scalar":
            return f"ScalarList[{size}]" if size is not None else "ScalarList"
        elif str(t.elem) == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            else:
                return "const c10::List<::std::optional<Tensor>> &"
        elif str(t.elem) == "Dimname":
            return f"DimnameList[{size}]" if size is not None else "DimnameList"
        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{elem}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")
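

# A few illustrative mappings (hand-checked against the branches above):
#   'Tensor'    -> 'Tensor'
#   'int'       -> 'int64_t'
#   'int[2]'    -> 'IntArrayRef[2]' (or plain 'IntArrayRef' with simple_type=True)
#   'Tensor?'   -> 'Tensor?'
#   'Dimname[]' -> 'DimnameList'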


def argument_type_size(t: Type) -> int | None:
    l = t.is_list_like()
    if l is not None and str(l.elem) != "bool":
        return l.size
    else:
        return None


def argument(a: Argument) -> PythonArgument:
    return PythonArgument(
        name=a.name,
        type=a.type,
        # TODO: directly translate a.default to python default
        default=(
            str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
            if a.default is not None
            else None
        ),
        default_init=None,
    )


# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    return signature_from_schema(
        f.func, category_override=f.category_override, method=method, pyi=pyi
    )


def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    args: list[Argument] = []
    args.extend(func.arguments.pre_self_positional)
    # Skip SelfArgument if this is a method.
    if not method and func.arguments.self_arg is not None:
        args.append(func.arguments.self_arg.argument)
    args.extend(func.arguments.post_self_positional)
    args.extend(func.arguments.pre_tensor_options_kwarg_only)
    # Skip TensorOptionsArguments. Python-side TensorOptions
    # arguments are created based on different rules - see below.
    args.extend(func.arguments.post_tensor_options_kwarg_only)
    args.extend(func.arguments.out)

    input_arg_set = {a.name for a in func.arguments.flat_positional}
    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in func.arguments.out}

    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
    input_kwargs = tuple(
        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
    )
    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have a new property
    # (default_init) and a new argument 'requires_grad', which require some
    # special handling.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in func.arguments.flat_non_out
    )
    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
        raise ValueError(
            "argument named requires_grad is reserved, should not explicitly add it in the schema"
        )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            topt_args = func.arguments.tensor_options
            if topt_args is None:
                return None
            a = getattr(topt_args, name)
            if a.default is None or a.default == "None":
                return None
            return cpp.default_expr(a.default, a.type, symint=False)

        tensor_options_args.append(
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("dtype")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("layout")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=(
                    None
                    if is_like_or_new_function
                    else (
                        topt_default_init("device")
                        or "torch::tensors::get_default_device()"
                    )
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )

    returns = PythonReturns(returns=func.returns)

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=returns,
        method=method,
    )


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                          Python Interface
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    if len(returns) <= 1 or all(r.name is None for r in returns):
        return []
    else:
        if any(r.name is None for r in returns):
            # When building on Windows, `PyStructSequence_UnnamedField` could not be
            # resolved by the linker for some reason, which caused a build error:
            #
            #   python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
            #   PyStructSequence_UnnamedField
            #
            # Thus, at this point in time, we do not support unnamed
            # fields in structseq; you must either name all fields
            # or none of them.
            raise ValueError("Unnamed field is not supported by codegen")

        return [str(r.name) for r in returns]
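

# e.g. (illustrative): aten::max.dim returns (Tensor values, Tensor indices), so
# structseq_fieldnames yields ['values', 'indices']; a single-Tensor return yields [].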


def argument_type_str_pyi(t: Type) -> str:
    add_optional = False
    if isinstance(t, OptionalType):
        t = t.elem
        add_optional = True

    if isinstance(t, BaseType):
        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
            ret = "_int"
        if t.name == BaseTy.SymInt:
            ret = "Union[_int, SymInt]"
        elif t.name == BaseTy.float:
            ret = "_float"
        elif t.name == BaseTy.str:
            ret = "str"
        elif t.name == BaseTy.Scalar:
            ret = "Union[Number, _complex]"
        elif t.name == BaseTy.ScalarType:
            ret = "_dtype"
        elif t.name == BaseTy.bool:
            ret = "_bool"
        elif t.name == BaseTy.QScheme:
            ret = "_qscheme"
        elif t.name == BaseTy.Layout:
            ret = "_layout"
        elif t.name == BaseTy.Device:
            ret = "Optional[DeviceLikeType]"
        elif t.name == BaseTy.MemoryFormat:
            ret = "memory_format"
        elif t.name == BaseTy.Dimname:
            ret = "Union[str, ellipsis, None]"
        elif t.name == BaseTy.Storage:
            ret = "Union[Storage, UntypedStorage]"
        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
            # These python schema type names line up with their function schema names
            ret = t.name.name

    elif isinstance(t, ListType):
        if str(t.elem) == "int":
            ret = "Union[_int, _size]" if t.size is not None else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
            # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
            if isinstance(t.elem, OptionalType):
                add_optional = True
            ret = (
                "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
                if t.size is not None
                else "Union[Tuple[Tensor, ...], List[Tensor]]"
            )
        elif str(t.elem) == "float":
            ret = "Sequence[_float]"
        elif str(t.elem) == "SymInt" and t.size is not None:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Union[{elem}, Sequence[{elem}]]"
        else:
            elem = argument_type_str_pyi(t.elem)
            ret = f"Sequence[{elem}]"
    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        ret = "Optional[" + ret + "]"

    return ret


def return_type_str_pyi(t: Type) -> str:
    # Where arguments are open to accepting Union, return types should return
    # concrete types.

    if isinstance(t, OptionalType):
        inner = return_type_str_pyi(t.elem)
        return f"Optional[{inner}]"

    if isinstance(t, BaseType):
        if t.name == BaseTy.Device:
            return "_device"
        elif t.name == BaseTy.Dimname:
            return "Optional[str]"
        else:
            return argument_type_str_pyi(t)

    if isinstance(t, ListType):
        inner = return_type_str_pyi(t.elem)
        return f"Tuple[{inner}, ...]"

    return argument_type_str_pyi(t)
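

# e.g. (illustrative): 'Tensor?' -> 'Optional[Tensor]', 'int[]' -> 'Tuple[_int, ...]',
# and 'Device' -> '_device'.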


def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    structseq_name = signature.name
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        # These types are structseq objects which act like NamedTuples, but
        # their constructor acts like the constructor of tuple. Using typing.NamedTuple
        # does not allow us to override __init__.
        seq_type = f"Tuple[{', '.join(python_returns)}]"
        structseq_def_lines = [
            f"class {structseq_name}({seq_type}):",
        ]
        for name, typ in zip(field_names, python_returns):
            structseq_def_lines.extend(
                [
                    "    @property",
                    f"    def {name}(self) -> {typ}: ...",
                ]
            )
        structseq_def_lines.extend(
            [
                f"    def __new__(cls, sequence: {seq_type}): ...",
                f"    n_fields: _int = {len(field_names)}",
                f"    n_sequence_fields: _int = {len(field_names)}",
                "    n_unnamed_fields: _int = 0",
                "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
                "",  # add an extra newline
            ]
        )
        structseq_def = "\n".join(structseq_def_lines)
        # Example:
        # structseq_def = (
        #     "class max(Tuple[Tensor, Tensor]):\n"
        #     "    @property\n"
        #     "    def values(self) -> Tensor: ...\n"
        #     "    @property\n"
        #     "    def indices(self) -> Tensor: ...\n"
        #     "    def __new__(cls, sequence: Tuple[Tensor, Tensor]): ...\n"
        #     "    n_fields: _int = 2\n"
        #     "    n_sequence_fields: _int = 2\n"
        #     "    n_unnamed_fields: _int = 0\n"
        #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing\n"
        # )
        return structseq_name, structseq_def
    return None


def returns_str_pyi(signature: PythonSignature) -> str:
    field_names = structseq_fieldnames(signature.returns.returns)
    if field_names:
        return f"torch.return_types.{signature.name}"

    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
    if len(python_returns) > 1:
        return "Tuple[" + ", ".join(python_returns) + "]"
    if len(python_returns) == 1:
        return python_returns[0]
    return "None"
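

# e.g. (illustrative): a named-field return like aten::max.dim yields
# 'torch.return_types.max'; an unnamed two-tensor return yields 'Tuple[Tensor, Tensor]'.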


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                        C++ Function Dispatch
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# This section provides APIs to generate the code that does C++ function
# dispatch. The C++ function call is wrapped by a lambda function.
# For example:
#
#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
#      pybind11::gil_scoped_release no_gil;
#      return at::selu_(self);
#    };
#
# The lambda function's signature follows the C++ signature in common
# cases, e.g.:
#
#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For an out variant the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because, when calling the lambda, we pass in the
# PythonArgParser output '_r.tensor(3)', which is a stack-allocated object
# and needs to be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
#
#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
#
# For the multi-output case it can keep using reference types, because the
# PythonArgParser output has been unpacked to local variables, e.g.:
#
#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
#   //   Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
#
# For a deprecated python signature, it should follow the deprecated python arg order.
# TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?


def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    if isinstance(ps, PythonSignatureDeprecated):
        schema = ps.deprecated_schema
    else:
        schema = f.func

    # Start with cpp arguments - the dispatch lambda signature always includes 'self'
    cpp_args = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_args: set[str] = {a.name for a in schema.arguments.out}

    # Convert from cpp argument to lambda argument
    def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
        type_str = cpp_arg.type
        is_out_arg = cpp_arg.name in out_args
        if ps.method and cpp_arg.name == "self":
            # For a method's 'self', we can use 'const Tensor &' and simply ignore mutability!
            type_str = "const at::Tensor &"
        else:
            # For other cases we need to prevent dangling refs to temps (unless it's
            # an unpacked scattered output).
            # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
            if ensure_temp_safe:
                type_str = {
                    "at::Tensor &": "at::Tensor",
                }.get(type_str, type_str)
        return DispatchLambdaArgument(
            name=cpp_arg.name,
            type_str=type_str,
            is_out_arg=is_out_arg,
        )

    return tuple(map(dispatch_lambda_arg, cpp_args))
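

# e.g. (illustrative) for aten::add.out the lambda args are roughly:
#   out:  type_str 'at::Tensor' (demoted from 'at::Tensor &'), is_out_arg=True
#   self: type_str 'const at::Tensor &', is_out_arg=False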


# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    "at::Tensor",
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    "::std::vector<at::Tensor>",
    # Needed for flash attention forward/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
}


def dispatch_lambda_return_str(f: NativeFunction) -> str:
    # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
    # because the dispatch lambdas take mutable arguments *by value*, not
    # by reference. If you then return a reference to such an argument, you
    # will now have a pointer to a dangling stack entry. Not good.
    #
    # You want:
    #
    #    auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #
    # *not*
    #
    #    auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to a temporary. Maybe we could assign it to a named
    # variable first?)
    returns_without_annotation = tuple(
        Return(r.name, r.type, None) for r in f.func.returns
    )
    return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
    if return_str not in SUPPORTED_RETURN_TYPES:
        raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
    return return_str


def cpp_dispatch_target(f: NativeFunction) -> str:
    symint = f.func.has_symint()
    name = cpp.name(f.func, symint_overload=symint)
    if Variant.method in f.variants:
        return f"self.{name}"
    if Variant.function in f.variants:
        if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
            namespace = "torch"
        else:
            namespace = "at"
        return f"{namespace}::{name}"
    raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")


def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()

    exprs: tuple[str, ...] = ()
    if not isinstance(python_signature, PythonSignatureDeprecated):
        # By default the exprs are consistent with the C++ signature.
        exprs = tuple(a.name for a in cpp_args)
    else:
        # For a deprecated python signature we may need to fill in some constants.
        exprs = tuple(
            filter(
                lambda n: n != "out" or f.func.is_out_fn(),
                python_signature.deprecated_args_exprs,
            )
        )

    if Variant.method in f.variants:
        exprs = tuple(filter("self".__ne__, exprs))

    return exprs


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                     Python / C++ Args Binding
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# We explicitly enumerate the PythonArgParser unpacking methods for all
# supported types. This might be more verbose than necessary, partially
# because of the irregularity of unpacking method naming, partially
# because we want to mimic the old codegen behavior - to reject
# unexpected and/or unsupported cases which the old codegen rejects.
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accept doublelist with definite size.
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    has_default_init = default_init is not None
    if has_default_init and str(t) not in (
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    ):
        raise RuntimeError(f"type '{t}' does not support unpacking with default")

    if isinstance(t, BaseType):
        if t.name in [
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ]:
            # These unpack methods line up with their schema names
            return t.name.name.lower()
        elif t.name == BaseTy.ScalarType:
            return "scalartypeWithDefault" if has_default_init else "scalartype"
        elif t.name == BaseTy.Device:
            return "deviceWithDefault" if has_default_init else "device"
        elif t.name == BaseTy.DeviceIndex:
            return "toInt64"
        elif t.name == BaseTy.int:
            return "toInt64"
        elif t.name == BaseTy.SymInt:
            return "toSymInt" if symint else "toInt64"
        elif t.name == BaseTy.bool:
            return "toBoolWithDefault" if has_default_init else "toBool"
        elif t.name == BaseTy.float:
            return "toDouble"
        elif t.name == BaseTy.str:
            return "stringView"
        elif t.name == BaseTy.Layout:
            return "layoutWithDefault" if has_default_init else "layout"
        elif t.name == BaseTy.MemoryFormat:
            return "memoryformat"

    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            return "optionalTensor"
        elif str(t.elem) == "Generator":
            return "generator"
        elif str(t.elem) == "Dimname[]":
            return "toDimnameListOptional"
        elif not has_default_init and default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        ):
            # If default is None: append 'Optional' to elem's unpacking method
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        else:
            # Otherwise, load as the underlying type with default
            return arg_parser_unpack_method(
                t.elem, default, default_init, symint=symint
            )

    elif isinstance(t, ListType):
        if str(t.elem) == "Tensor":
            # accept and use definite size
            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
        elif str(t.elem) == "Tensor?":
            return "list_of_optional_tensors"
        elif str(t.elem) == "Dimname":
            # accept definite size
            return "dimnamelist"
        elif str(t.elem) == "int":
            # accept definite size
            return "intlist"
        elif str(t.elem) == "float":
            return "doublelist"
        elif str(t.elem) == "SymInt":
            # accept definite size
            return "symintlist" if symint else "intlist"
        elif str(t.elem) == "Scalar":
            return "scalarlist"
    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
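

# A few illustrative results (hand-checked against the branches above):
#   'Tensor' -> 'tensor', 'int' -> 'toInt64', 'Tensor?[]' -> 'list_of_optional_tensors',
#   'ScalarType?' with a default_init -> 'scalartypeWithDefault'.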


# Return the RHS expression for a python argument using the PythonArgParser output.
# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, it returns '_r.toBool(2)'
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    has_default = a.default_init is not None
    unpack_method = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )
    default = f", {a.default_init}" if has_default else ""
    expr = f"_r.{unpack_method}({arg_index}{default})"

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=expr,
        index=arg_index,
        argument=a,
    )
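

# e.g. (from the note near the top of this file) a layout argument at index 3
# with a default_init produces:
#   '_r.layoutWithDefault(3, layout_from_backend(self.options().backend()))'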


# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    return {
        e.name: e
        for i, a in enumerate(ps.arguments())
        for e in (arg_parser_output_expr(i, a, symint=symint),)
    }


# argument name to type for scattered tensor options fields
TENSOR_OPTIONS_FIELDS = {
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}


# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    # This method binds 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using the arg
    # parser outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why does this need to be a special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly use the PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to the python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        inits.append(
            f"""\
const auto options = TensorOptions()
    .dtype({arg_parser_outputs['dtype'].expr})
    .device({arg_parser_outputs['device'].expr})
    .layout({arg_parser_outputs['layout'].expr})
    .requires_grad({arg_parser_outputs['requires_grad'].expr})
    .pinned_memory({arg_parser_outputs['pin_memory'].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against the output tensor
            if not f.func.is_out_fn():
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
                       {arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
                       {arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
"""
            )
        # we'll set requires_grad on the outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
            )

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )