1
from __future__ import annotations
5
from dataclasses import dataclass
8
import torchgen.api.dispatcher as dispatcher
9
from torchgen.api.lazy import (
17
from torchgen.api.translate import translate
18
from torchgen.api.types import (
28
from torchgen.context import method_with_native_function
29
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
30
from torchgen.model import (
43
# Build the C++ rvalue expression used to pass `arg` into a lazy IR Node
# constructor: IR-value args become lazy_<name>->GetIrValue() / node_<name> /
# tensorlist / SymInt helper calls; scalar and list args are copied into owned
# std:: containers.
# NOTE(review): this chunk is an extraction artifact — original line numbers
# appear as bare-integer lines, leading indentation is stripped, and several
# original lines are elided (gaps in the interleaved numbering). All
# non-comment tokens below are preserved byte-for-byte.
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
46
generate a c++ string for materializing an rvalue of that arg for passing into
47
a lazy Node constructor.
51
# Value-typed arguments (IR values) are handled first.
if isValueType(arg.lazy_type):
52
if isinstance(arg.lazy_type, BaseCType):
53
if arg.is_wrapped_scalar:
54
# Wrapped scalars are pre-materialized by the caller as node_<name>.
return f"node_{arg.name}"
55
elif arg.lazy_type.type is tensorListValueT:
56
return f"lazy_{arg.name}_tensorlist"
57
elif arg.is_symint_or_list:
58
return f"GetSymIntValue({arg.name})"
59
return f"lazy_{arg.name}->GetIrValue()"
60
elif isinstance(arg.lazy_type, OptionalCType):
61
if arg.is_symint_or_list:
64
return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
65
elif arg.is_wrapped_scalar:
66
return f"node_{arg.name}"
69
# NOTE(review): elided context — presumably the middle of a C++ ternary
# returning an optional IR value; confirm against the pristine file.
f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
74
# NOTE(review): elided context — presumably the message of a raised error
# for unsupported value types.
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
81
# Non-value (scalar / list) arguments: materialize owned copies.
if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
85
return f"GetSymIntArrayRefValue({arg.name})"
87
return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
88
elif isinstance(arg.lazy_type, VectorCType) and isinstance(
89
arg.lazy_type.elem, BaseCType
91
return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
93
isinstance(arg.lazy_type, OptionalCType)
94
and isinstance(arg.lazy_type.elem, VectorCType)
95
and isinstance(arg.lazy_type.elem.elem, BaseCType)
97
return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
102
# Join the per-argument rvalue strings (node_ctor_arg_rvalue_string) into the
# comma-separated argument list for a lazy Node subclass constructor.
# NOTE(review): the list-literal lines wrapping the comprehension are elided
# in this extraction; tokens below are byte-identical to SOURCE.
def node_ctor_inputs(schema: LazyIrSchema) -> str:
104
Produce a formatted string with the arguments as passed into the constructor of a node class.
107
node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
109
return ", ".join(node_ctor_values)
112
# Emit the C++ eager-fallback preamble: when force_eager_fallback(<symbol>) is
# true, re-dispatch to the eager kernel via at::native's fallback helper,
# tagging the op with ATEN_OP2 when it has an overload name, else ATEN_OP.
# NOTE(review): several original lines are elided here (e.g. the
# `overload_name` parameter and the `else:` line before ATEN_OP).
def gen_fallback_code(
113
schema: LazyIrSchema,
114
sig: DispatcherSignature | NativeSignature,
118
Generate code that falls back to eager conditioned on a predicate
120
dispatcher_sig = DispatcherSignature.from_schema(schema.func)
121
# Map the kernel signature's bindings onto dispatcher-order argument exprs.
exprs = translate(sig.arguments(), dispatcher_sig.arguments())
122
fallback_args = ",\n                ".join([a.expr for a in exprs])
123
if len(overload_name):
124
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
126
aten_op_str = f"ATEN_OP({schema.aten_name})"
128
if (force_eager_fallback({aten_symbol(schema)})) {{
129
# NOTE(review): `<<c_eager_fallback` looks HTML-unescape-garbled — the
# upstream text is plausibly `<&ltc_eager_fallback` (the `&` of
# `&ltc_eager_fallback` was eaten). Verify against the pristine file.
return at::native::call_fallback_fn_symint<<c_eager_fallback, {aten_op_str}>::call(
136
# Return the C++ expression naming the ATen symbol for this schema.
# Ops listed in `missing_interned_strings` (body elided in this extraction)
# have no interned at::aten:: symbol and fall back to fromQualString; names
# not already qualified with "at::" are prefixed with "at::aten::".
def aten_symbol(schema: LazyIrSchema) -> str:
137
missing_interned_strings = {
140
if schema.aten_name in missing_interned_strings:
141
return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
143
if not schema.aten_name.startswith("at::"):
144
return f"at::aten::{schema.aten_name}"
146
return schema.aten_name
152
# For every tensor-like argument of the dispatcher signature, emit a C++
# statement converting it to a meta tensor (`auto <name>_meta = to_meta(...)`)
# and rebind the context entry to the meta name. Returns the joined statement
# string and the rebound bindings.
# NOTE(review): the else-branch (appending the unchanged binding for
# non-tensor args) is elided in this extraction.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
153
context: list[Binding] = []
154
unwrapped_tensor_args: list[str] = []
155
for arg in sig.arguments():
156
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
157
unwrapped_name = f"{arg.name}_meta"
158
unwrapped_tensor_args.append(
159
f"auto {unwrapped_name} = to_meta({arg.name});"
161
context.append(arg.with_name(unwrapped_name))
164
unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
165
return unwrap_tensor_args_str, context
168
# Base lazy-IR codegen dataclass (the `class` header line itself is elided in
# this extraction; the name is GenLazyIR per `class GenTSLazyIR(GenLazyIR)`
# below). Generates the C++ Node subclass declaration for one op schema.
@dataclass(frozen=True)
170
backend_index: BackendIndex
175
@method_with_native_function
176
# Entry point: derive the LazyIrSchema (functional variant for grouped ops)
# and delegate to gen().
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
177
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
178
metadata = self.backend_index.get_kernel(
179
f.functional if isinstance(f, NativeFunctionsGroup) else f
181
schema = LazyIrSchema(
182
func, symint=metadata is not None and metadata.supports_symint()
184
return self.gen(schema)
188
# Hook: subclass-provided Lower() method body (empty/default here; body
# elided in this extraction).
def lowering_function(self, schema: LazyIrSchema) -> str:
191
# Hook: subclass-provided static Create() factory (body elided).
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
194
# Default CanBeReused(): C++ method stub (remainder of the f-string elided).
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
195
return f"""bool CanBeReused({node_ctor_args}) const {{
199
# Build the base-class constructor call: op kind, operand list, shape
# argument (precomputed / eagerly computed / lazily cached), num outputs,
# and the hash of all scalar args.
def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
200
value_args = schema.filtered_args(values=True, scalars=False)
203
base_ctor_value_args_list = []
204
for arg in value_args:
205
if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
206
base_ctor_value_args_list.append(f"{arg.name}")
207
elif isinstance(arg.lazy_type, OptionalCType):
208
# Optional operands degrade to kNullValue when absent.
base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
210
raise AssertionError(
211
f"Unsupported type ({arg.lazy_type}) - add support if necessary"
213
base_ctor_value_args = ", ".join(base_ctor_value_args_list)
215
scalar_args = schema.filtered_args(values=False, scalars=True)
219
if schema.properties.ShapePrecompute:
220
# Shapes were computed by the caller and are moved in.
shape_ctor_arg = "std::move(shapes),"
221
elif schema.properties.ShapeCompute:
222
shape_args = [a.name for a in value_args]
223
shape_args.extend(a.name for a in scalar_args)
224
shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
225
elif schema.properties.ShapeCache:
226
# Lazily computed via a lambda over the node's operands.
shape_args = [f"operand({i})" for i in range(len(value_args))]
227
shape_args.extend(a.name for a in scalar_args)
228
shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
232
scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
234
return f"""{self.node_base}(
235
{schema.node_name}::ClassOpKind(),
236
OpList{{{base_ctor_value_args}}},
238
/* num_outputs */ {len(schema.returns)},
239
torch::lazy::MHash({scalar_hashes}))"""
241
# Assemble the full C++ Node subclass declaration for this schema:
# constructor args, scalar member initializers/declarations, optional-value
# presence bits, ToString(), Create/CanBeReused/Lower hooks.
def gen(self, schema: LazyIrSchema) -> list[str]:
242
opkind = schema.opkind or aten_symbol(schema)
246
all_args = schema.filtered_args()
247
scalar_args = schema.filtered_args(values=False, scalars=True)
249
ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
250
# CanBeReused/Create take the args WITHOUT the trailing shapes parameter.
reuse_ctor_args = ", ".join(ctor_args)
251
if self.use_lazy_shape and schema.properties.ShapePrecompute:
252
ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
253
node_ctor_args = ", ".join(ctor_args)
255
scalar_initializers = ",\n        ".join(
258
# string_view members are copied into owned std::string storage.
f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
259
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
260
else f"{a.name}({a.name})"
264
if len(scalar_initializers):
265
scalar_initializers = f",\n        {scalar_initializers}"
266
scalar_decls = "\n  ".join(
268
f"std::string {a.name};"
269
if a.lazy_type.cpp_type() == "c10::string_view"
270
else f"::std::optional<std::string> {a.name};"
271
if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
272
else f"{a.lazy_type.cpp_type()} {a.name};"
278
for arg in schema.filtered_args(values=True, scalars=False)
279
if isinstance(arg.lazy_type, OptionalCType)
281
# One presence bit per optional value operand.
has_optional_decls = "\n  ".join(
282
[f"bool has_{value}: 1;" for value in optional_values]
284
has_optional_defs = "\n    ".join(
285
[f"has_{value} = !!{value};" for value in optional_values]
287
members_to_string = []
288
for arg in scalar_args:
289
if isinstance(arg.lazy_type, OptionalCType):
290
value = f"{arg.name}.value()"
292
# NOTE(review): elided context — presumably a generator-typed branch
# printing a placeholder; confirm against the pristine file.
value = '"torch.Generator()"'
293
members_to_string.append(
294
f"""if ({arg.name}.has_value()) {{
295
ss << ", {arg.name}=" << {value};
297
ss << ", {arg.name}=null";
301
members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
302
members_to_string_str = "\n    ".join(members_to_string)
306
class {schema.node_name} : public {self.node_base} {{
308
static torch::lazy::OpKind ClassOpKind() {{
309
return torch::lazy::OpKind({opkind});
312
{schema.node_name}({node_ctor_args})
313
: {self.node_base_ctor_call(schema)}{scalar_initializers}
318
std::string ToString() const override {{
319
std::stringstream ss;
320
ss << {self.node_base}::ToString();
321
{members_to_string_str}
325
{self.create_function(schema, reuse_ctor_args)}
327
{self.can_be_reused_function(schema, reuse_ctor_args)}
329
{self.lowering_function(schema)}
340
# TorchScript-backend specialization of GenLazyIR: overrides the Lower(),
# Create(), and CanBeReused() hooks to emit TS-specific C++ bodies, each
# gated by the schema's LazyIrProperties flags.
@dataclass(frozen=True)
341
class GenTSLazyIR(GenLazyIR):
342
def lowering_function(self, schema: LazyIrSchema) -> str:
344
torch::lazy::TSOpVector Lower(
345
std::shared_ptr<torch::jit::GraphFunction> function,
346
torch::lazy::TSLoweringContext* loctx) const override"""
348
# Declaration-only: header gets just the prototype.
if schema.properties.LowerDeclOnly:
349
return f"{signature};"
350
elif schema.properties.Lower:
351
return f"""{signature} {{
352
{ts_lowering_body(schema)}
358
def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
359
signature = f"static NodePtr Create({node_ctor_args})"
360
if schema.properties.CreateFnDeclOnly:
361
return f"{signature};"
362
elif not schema.properties.CreateFn:
364
return f"""{signature} {{
365
return ReuseOrMakeNode<{schema.node_name}>(data);
368
def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
369
signature = f"bool CanBeReused({node_ctor_args}) const"
370
if schema.properties.CanBeReusedDeclOnly:
371
return f"{signature};"
372
elif not schema.properties.CanBeReused:
374
# Build the &&-chained comparison of each operand/member against the
# incoming args (i++ walks the node's operand list in order).
value_comparison = []
375
for arg in itertools.chain(schema.positional_values, schema.keyword_values):
376
if isinstance(arg.lazy_type, OptionalCType):
377
value_comparison.append(
378
f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
381
value_comparison.append(f"operand(i++) == {arg.name}")
382
for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
383
if isinstance(arg.lazy_type, OptionalCType):
384
value_comparison.append(
385
# Optionals are equal when both absent, or both present and equal.
f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
388
value_comparison.append(f"this->{arg.name} == {arg.name}")
389
value_comparison_str = " &&\n        ".join(value_comparison)
391
return f"""{signature} {{
393
return ({value_comparison_str});
397
# Generates the C++ native-function kernel definition for a lazy backend:
# eager fallback, metrics counter, device resolution, lazy-tensor decls,
# IR-node construction (with shape inference), and the ATen return bridge.
# NOTE(review): several dataclass fields and statement lines are elided in
# this extraction (gaps in the interleaved original line numbers).
@dataclass(frozen=True)
398
class GenLazyNativeFuncDefinition:
399
class_method_name: str
400
backend_index: BackendIndex
402
gen_forced_fallback_code: bool
403
backend_namespace: str
405
get_tensor_or_wrap_number: str
409
create_from_first_tensor: bool
410
create_aten_from_ltc_tensor: str
411
tuple_aten_from_ltc_tensors: str
415
# Emit one C++ declaration per IR-value argument: wrapped scalars become
# node_<name> IR values, tensorlists become lazy_<name>_tensorlist, plain
# tensors become lazy_<name> LazyTensor pointers.
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
416
value_args = schema.filtered_args(values=True, scalars=False)
418
lazy_tensor_decls: list[str] = []
419
for arg in value_args:
420
if arg.is_wrapped_scalar:
421
if isinstance(arg.lazy_type, OptionalCType):
422
lazy_tensor_decls.append(
423
f"""auto node_{arg.name} = {arg.name} ?
424
std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
425
GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
429
lazy_tensor_decls.append(
430
f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
431
GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
433
# SymInt args need no declaration here (handled at ctor-arg time).
elif arg.is_symint_or_list:
435
elif isinstance(arg.lazy_type, BaseCType):
436
if arg.lazy_type.type is tensorListValueT:
437
lazy_tensor_decls.append(
438
f"auto lazy_{arg.name}_tensorlist = "
439
f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
442
lazy_tensor_decls.append(
443
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
444
f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
446
elif isinstance(arg.lazy_type, OptionalCType):
447
assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
450
lazy_tensor_decls.append(
451
f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
452
f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
455
raise AssertionError(
456
f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
458
return ("\n        ").join(lazy_tensor_decls)
460
# Emit the eager-fallback guard when forced fallbacks are enabled;
# otherwise (elided branch) presumably returns an empty string.
def force_eager_fallback(
462
func: NativeFunction,
463
schema: LazyIrSchema,
464
metadata: BackendMetadata,
465
sig: DispatcherSignature | NativeSignature,
467
if self.gen_forced_fallback_code:
468
return gen_fallback_code(
469
schema, sig, overload_name=func.func.name.overload_name
473
# Emit the per-op metrics counter statement.
def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
474
return f"{self.metrics_counter};"
476
# Resolve the common lazy device from tensor args and optional Device
# scalars; at least one of the two must exist.
def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
477
value_args = schema.filtered_args(values=True, scalars=False)
478
scalar_args = schema.filtered_args(values=False, scalars=True)
479
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
480
optional_device = OptionalCType(BaseCType(deviceT))
482
a.name for a in scalar_args if a.lazy_type == optional_device
485
len(value_types_names) > 0 or len(optional_devices) > 0
486
), "Expected at least one Value or Device type"
488
f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
490
return f"""auto common_device = {get_device_str};
491
TORCH_INTERNAL_ASSERT(common_device);
494
# Emit C++ that computes output shapes: structured / view_copy ops call the
# meta kernel; others call a hand-written compute_shape_* function. Also
# wires in symbolic-shape propagation when enabled.
def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
495
metadata = self.backend_index.get_kernel(func)
496
assert metadata is not None
497
all_args = schema.filtered_args()
498
returns_length = len(schema.returns)
509
is_view_copy_op = "view_copy" in func.tags
510
is_structured = func.structured or func.structured_delegate is not None
511
if is_structured or is_view_copy_op:
513
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
514
if returns_length > 1:
516
# Multi-output meta kernels return a tuple; extract each shape.
def this_shape(i: int) -> str:
517
return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
519
shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
520
meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
526
dispatcher_sig = DispatcherSignature.from_schema(func.func)
527
meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
531
meta_call_ctx, dispatcher_sig.arguments(), method=False
536
# view_copy ops must have this kernel so the meta call resolves.
assert func.has_composite_explicit_autograd_non_functional_kernel
537
dispatch_ns = "compositeexplicitautogradnonfunctional"
540
aten_name = schema.aten_name
542
if func.func.has_symint() and metadata.supports_symint():
543
aten_name += "_symint"
545
{meta_conversion_str}
546
auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
549
shape_sig = ComputeShapeSignature(
550
metadata.kernel, func, symint=metadata.supports_symint()
553
auto shapes = {shape_sig.shape_call};"""
556
TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
559
func_schema_str = "aten::" + str(func.func)
561
if(torch::lazy::symbolicShapeEnabled()){{
562
std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
563
const char* schema_str = "{func_schema_str}";
564
applySymbolicShapesOnLT(schema_str, inputs, shapes);
569
# Try to reuse a cached node; otherwise run shape inference and make one.
def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
570
node_ctor_input_str = node_ctor_inputs(schema)
571
return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
573
{self.shape_inference(func, schema)}
574
node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
579
# Return the C++ expression that builds a LazyTensor, either via a method
# on the first tensor argument or via a backend-namespace free function.
def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
581
if self.create_from_first_tensor:
584
first_tensor_name is not None
585
), "Requires first tensor to create lazy tensor"
586
return f"{first_tensor_name}.{self.create_tensor}"
587
return f"{self.backend_namespace}::{self.create_tensor}"
589
# Bridge the IR node back to ATen: single result, tuple of results, or —
# for in-place/out ops — set the IR value on the existing lazy tensor.
def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
590
returns_length = len(schema.returns)
591
value_args = schema.filtered_args(values=True, scalars=False)
592
value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
593
first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
594
bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
595
{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
597
if returns_length > 1:
599
len(value_types_names) > 0
600
), "Code below assumes there is at least one tensor arg"
601
bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
602
for (int i = 0; i < {returns_length}; i++) {{
603
lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
605
auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
607
if schema.name.name.inplace or func.func.is_out_fn():
608
assert returns_length == 1, (
609
"We assumed there was no such case where an op is an in-place variant "
610
f"and has tuple outputs, but got tuple of len {returns_length}."
612
bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
613
auto& result = {first_tensor_name};"""
619
@method_with_native_function
620
# Entry point: stitch all the pieces into one kernel definition string.
def __call__(self, func: NativeFunction) -> list[str]:
621
sig = kernel_signature(func, self.backend_index)
622
metadata = self.backend_index.get_kernel(func)
623
assert metadata is not None
624
schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
627
{sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
628
{self.force_eager_fallback(func, schema, metadata, sig)}
629
{self.metrics(func, schema)}
630
{self.get_device(func, schema)}
631
{self.lazy_tensor_decls(func, schema)}
632
{self.build_ir_node(func, schema)}
633
{self.return_aten_tensor(func, schema)}
639
# Holds the declaration and call strings for a compute_shape_<kernel> helper;
# the kernel base name keys the signature so in-place variants share it.
class ComputeShapeSignature:
641
Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
644
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
645
self.__schema = LazyIrSchema(f.func, symint=symint)
646
# Declaration args come from the dispatcher API; call args from the
# schema's filtered args (generator args included).
self.__dispatch_args = ", ".join(
647
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
649
self.__call_args = ", ".join(
650
[f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
652
self.__kernel_name = kernel_name
654
def __decl_suffix(self) -> str:
655
return f"{self.__kernel_name}({self.__dispatch_args})"
657
def __call_suffix(self) -> str:
658
return f"{self.__kernel_name}({self.__call_args})"
661
# NOTE(review): likely @property-decorated upstream (lines elided) — it is
# used as `shape_sig.shape_decl` / `shape_sig.shape_call` without parens.
def shape_decl(self) -> str:
662
return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
665
def shape_call(self) -> str:
666
return f"torch::lazy::compute_shape_{self.__call_suffix()}"
669
# Emits the TORCH_API declaration of a compute_shape_* helper for ops that
# are neither structured nor view_copy (those get shapes from meta kernels).
@dataclass(frozen=True)
670
class GenLazyShapeInferenceDefinition:
671
backend_index: BackendIndex
674
@method_with_native_function
675
def __call__(self, f: NativeFunction) -> list[str]:
676
metadata = self.backend_index.get_kernel(f)
677
assert metadata is not None
680
is_view_copy_op = "view_copy" in f.tags
681
is_structured = f.structured or f.structured_delegate is not None
682
# Structured / view_copy ops infer shape via meta kernels — no decl needed
# (the early-return branch is elided in this extraction).
if is_structured or is_view_copy_op:
685
shape_sig = ComputeShapeSignature(
686
metadata.kernel, f, symint=metadata.supports_symint()
688
return ["\n".join([f"{shape_sig.shape_decl};"])]
691
# Generate IR node classes for ops declared directly in YAML rather than in
# native_functions (default properties: ShapeCache, CanBeReused,
# LowerDeclOnly; per-op overrides via the "properties" list).
# NOTE(review): the `nodes` list initialization and the return line are
# elided in this extraction.
def generate_non_native_lazy_ir_nodes(
692
non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
694
"""Generate the non-native lazy IR node classes"""
696
for op in non_native:
698
properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
699
for p in op.get("properties", []):
700
setattr(properties, p, True)
703
# symint=True: non-native ops are always parsed with SymInt support.
schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
704
schema.opkind = op.get("opkind")
705
nodes.append(gen_lazy_ir.gen(schema)[0])