5
from collections import Counter, namedtuple
22
import torchgen.dest as dest
24
from torchgen.api.lazy import setValueT
25
from torchgen.api.types import BaseCppType
26
from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
27
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
29
from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
30
from torchgen.selective_build.selector import SelectiveBuilder
31
from torchgen.utils import concatMap, FileManager, NamespaceHelper
32
from torchgen.yaml_utils import YamlLoader
33
from .gen_backend_stubs import (
34
error_on_missing_kernels,
35
gen_dispatcher_registrations,
36
gen_dispatchkey_nativefunc_headers,
100
# Record type with one named slot per piece of information about an external
# backend; instances are plain tuples, so they can also be unpacked
# positionally in the order listed below.
ParsedExternalYaml = namedtuple(
    "ParsedExternalYaml",
    [
        "backend_key",
        "autograd_key",
        "cpp_namespace",
        "backend_indices",
        "full_codegen",
    ],
)
106
def parse_native_functions_keys(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
    """Read the codegen operator lists from a backend yaml file.

    Args:
        backend_yaml_path: Path of the backend yaml file to read.
        grouped_native_functions: Kept for interface compatibility; the
            returned value does not depend on it.

    Returns:
        A 3-tuple ``(full_codegen, non_native, ir_gen)``. The ``full_codegen``
        and ``ir_gen`` entries are parsed into ``OperatorName`` objects; the
        ``non_native`` entries are returned exactly as loaded from the yaml.
        A key that is absent from the file yields an empty list.

    Raises:
        AssertionError: If the yaml document is not a mapping, or if any of
            the three keys holds something other than a list.
    """
    # NOTE(review): whether yaml.load is safe on untrusted input depends on
    # YamlLoader being a safe loader -- confirm in torchgen.yaml_utils.
    with open(backend_yaml_path) as f:
        yaml_values = yaml.load(f, Loader=YamlLoader)
    assert isinstance(yaml_values, dict)

    # pop() only touches the freshly loaded document, which is local to this
    # function, so removing the keys has no effect on any caller.
    full_codegen = yaml_values.pop("full_codegen", [])
    non_native = yaml_values.pop("non_native", [])
    ir_gen = yaml_values.pop("ir_gen", [])
    assert isinstance(full_codegen, list)
    assert isinstance(non_native, list)
    assert isinstance(ir_gen, list)

    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
    ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
    return full_codegen_opnames, non_native, ir_gen_opnames
133
def validate_shape_inference_header(
    shape_inference_hdr: str, expected_shape_infr_decls: List[str]
) -> None:
    """Check that every expected shape-inference declaration is in the header.

    A declaration counts as present only when it is equal to one complete
    line of the header file (the file is split on ``"\\n"`` and compared
    verbatim, so whitespace differences count as a mismatch).

    Args:
        shape_inference_hdr: Path of the header file to inspect.
        expected_shape_infr_decls: Declarations that must each appear as a
            line of the header.

    Raises:
        AssertionError: If the header file cannot be read.
        Exception: If at least one expected declaration is missing; the
            message names the header and lists the missing declarations.
    """
    # NOTE(review): the except clause was not visible in the reviewed text;
    # OSError is assumed as the failure mode of open()/read() -- confirm.
    try:
        with open(shape_inference_hdr) as f:
            shape_infr_decls = f.read()
    except OSError as e:
        raise AssertionError(
            f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
        ) from e

    # A set gives constant-time membership tests for the filter below.
    shape_infr_decl_lines = set(shape_infr_decls.split("\n"))

    missing_decls = [
        decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
    ]
    if not missing_decls:
        return

    # Each sentence is followed by a blank line; the missing declarations are
    # joined with the platform line separator so they can be copy-pasted.
    message = (
        "Missing shape inference function.\n\n"
        f"Please declare this function in {shape_inference_hdr}:\n\n"
        "and implement it in the corresponding shape_inference.cpp file.\n\n"
        f"{os.linesep.join(missing_decls)}"
    )
    # NOTE(review): the exception class raised here was not visible in the
    # reviewed text; plain Exception is assumed -- confirm before merging.
    raise Exception(message)
164
def get_ltc_helper_fns() -> str:
166
at::Tensor to_meta(const at::Tensor& tensor) {
167
// undefined tensors can't be converted to the meta device, since they don't have sizes/strides
168
if (!tensor.defined()) return tensor;
169
auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
170
/*dtype=*/c10::make_optional(tensor.scalar_type()), /*layout=*/c10::make_optional(tensor.layout()), \
171
/*device=*/c10::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/c10::nullopt);
172
// needs to handle wrapped numbers, so dtype promotion works properly.
173
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
174
out.unsafeGetTensorImpl()->set_wrapped_number(true);
178
c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
179
if (tensor.has_value()) {
180
return to_meta(*tensor);
185
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
186
std::vector<at::Tensor> outs;
187
outs.reserve(t_list.size());
188
for (const auto& tensor : t_list) {
189
outs.push_back(to_meta(tensor));
197
node_base: str = "Node"
198
node_base_hdr: Optional[str] = None
199
shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
200
tensor_class: str = "torch::lazy::LazyTensor"
201
tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
202
lazy_ir_generator: Type[GenLazyIR] = GenLazyIR
203
native_func_definition_generator: Type[
204
GenLazyNativeFuncDefinition
205
] = GenLazyNativeFuncDefinition
206
backend_name: str = "TorchScript"
210
parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
215
help="path to source yaml file containing operator external definitions",
217
parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
219
"--dry-run", "--dry_run", type=bool, default=False, help="output directory"
226
help="path to the source C++ file containing kernel definitions",
229
"--gen-ts-lowerings",
230
"--gen_ts_lowerings",
232
help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
238
default=default_args.node_base,
239
help="Name of backend specific custom Lazy IR Node base class",
245
default=default_args.node_base_hdr,
246
help="Path to header file defining custom Lazy IR Node base class",
249
"--shape-inference-hdr",
250
"--shape_inference_hdr",
252
default=default_args.shape_inference_hdr,
253
help="Path to header file defining custom Lazy shape inference functions",
259
default=default_args.tensor_class,
260
help="Name of backend specific custom Lazy Tensor class",
263
"--tensor-class-hdr",
264
"--tensor_class_hdr",
266
default=default_args.tensor_class_hdr,
267
help="Path to header file defining custom Lazy Tensor class",
273
default=default_args.backend_name,
274
help="Name of the backend to generate",
276
options = parser.parse_args()
279
torch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
280
aten_path = str(torch_root / "aten" / "src" / "ATen")
281
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator
282
if options.gen_ts_lowerings:
283
lazy_ir_generator = GenTSLazyIR
284
native_func_definition_generator: Type[
285
GenLazyNativeFuncDefinition
286
] = default_args.native_func_definition_generator
295
options.node_base_hdr,
296
options.tensor_class,
297
options.tensor_class_hdr,
298
options.shape_inference_hdr,
300
native_func_definition_generator,
301
options.backend_name,
305
def run_gen_lazy_tensor(
310
impl_path: Optional[str],
311
node_base: str = default_args.node_base,
312
node_base_hdr: Optional[str] = default_args.node_base_hdr,
313
tensor_class: str = default_args.tensor_class,
314
tensor_class_hdr: str = default_args.tensor_class_hdr,
315
shape_inference_hdr: str = default_args.shape_inference_hdr,
316
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator,
317
native_func_definition_generator: Type[
318
GenLazyNativeFuncDefinition
319
] = default_args.native_func_definition_generator,
321
build_in_tree: bool = False,
324
per_operator_headers: bool = False,
325
backend_name: str = default_args.backend_name,
326
gen_forced_fallback_code: bool = False,
327
use_lazy_shape: bool = True,
330
backend_namespace: str = "torch::lazy",
331
get_tensorlist: str = "GetTensorList",
332
get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
333
try_get_tensor: str = "TryGetLtcTensor",
334
metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
335
create_tensor: str = "LazyTensor::Create",
336
create_from_first_tensor: bool = False,
337
create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
338
tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
339
lazy_value_class: str = "torch::lazy::Value",
340
lazy_tensor_ptr: str = "LazyTensorPtr",
341
get_device_fn: str = "torch::lazy::GetBackendDevice",
343
lv_tokens = lazy_value_class.split("::")
344
lv_class = lv_tokens[-1]
345
lv_ns = "::".join(lv_tokens[:-1])
346
setValueT(BaseCppType(lv_ns, lv_class))
347
template_dir = os.path.join(aten_path, "templates")
349
def make_file_manager(install_dir: str) -> FileManager:
    """Build a FileManager that targets ``install_dir``.

    ``template_dir`` and ``dry_run`` are read from the enclosing scope, so
    every manager created here shares the same template location and
    dry-run setting.
    """
    # NOTE(review): the line opening this call was not visible in the
    # reviewed text; a direct ``return FileManager(...)`` is assumed from the
    # return annotation and the keyword arguments -- confirm.
    return FileManager(
        install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
    )
354
fm = make_file_manager(output_dir)
356
native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
357
tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
358
parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
359
native_functions, backend_indices = (
360
parsed_yaml.native_functions,
361
parsed_yaml.backend_indices,
363
grouped_native_functions = get_grouped_native_functions(native_functions)
365
def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
    """Return the string used as the sort key for ``f``.

    We sort the native function because of the note in concat_map_codegen.
    TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.

    For a NativeFunctionsGroup the key comes from its ``functional`` member;
    for a plain NativeFunction it comes from the function itself. In both
    cases the key is ``str(func.name.name)``.
    """
    if isinstance(f, NativeFunctionsGroup):
        func = f.functional.func
    else:
        func = f.func
    return str(func.name.name)
373
grouped_native_functions = sorted(
374
grouped_native_functions, key=sort_native_function
377
parsed_backend_yaml = parse_backend_yaml(
378
source_yaml, grouped_native_functions, backend_indices
380
backend_key = parsed_backend_yaml.backend_key
381
autograd_key = parsed_backend_yaml.autograd_key
382
cpp_namespace = parsed_backend_yaml.cpp_namespace
383
backend_indices = parsed_backend_yaml.backend_indices
389
full_codegen, non_native, ir_gen = parse_native_functions_keys(
390
source_yaml, grouped_native_functions
393
def concat_map_codegen(
394
func: Callable[[NativeFunction], Sequence[str]],
395
xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
396
ops_list: List[OperatorName] = full_codegen,
399
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
400
only code-gen additional entries for the inplace variant for the native functions.
404
fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
406
if f.func.name in ops_list:
409
selector = SelectiveBuilder.get_nop_selector()
411
assert backend_key is not None
412
class_name = backend_indices[backend_key].native_function_class_name()
414
if impl_path is not None:
415
error_on_missing_kernels(
425
""" Validate Shape Inference Definitions
427
Generated lazy native functions all perform shape inference, by first using a meta:: kernel
428
if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
429
knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
430
so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
431
to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
432
the expected signature which can be copy-pasted into shape_inference.h.
434
compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
435
to structured kernels.
437
See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
439
if shape_inference_hdr is not None:
440
expected_shape_infr_decls = list(
442
dest.GenLazyShapeInferenceDefinition(
443
backend_indices[backend_key], tensor_class
445
grouped_native_functions,
449
validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
450
assert class_name is not None
456
gen_dispatchkey_nativefunc_headers(
461
grouped_native_functions,
468
for dispatch_key in (
469
[backend_key] if autograd_key is None else [backend_key, autograd_key]
471
gen_dispatcher_registrations(
476
grouped_native_functions,
480
build_in_tree=build_in_tree,
481
per_operator_headers=per_operator_headers,
482
backend_name=backend_name,
483
eager_registration=False,
487
ns_helper = NamespaceHelper(cpp_namespace)
488
fm.write_with_template(
489
f"{backend_key}NativeFunctions.cpp",
490
"DispatchKeyNativeFunctions.cpp",
498
"ATen/native/TensorConversions.h",
499
"ATen/NativeFunctions.h",
500
"ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
501
"ATen/MetaFunctions.h",
503
"ATen/native/CPUFallback.h",
504
"torch/csrc/lazy/core/ir_builder.h",
505
"torch/csrc/lazy/core/lazy_graph_executor.h",
506
"torch/csrc/lazy/core/metrics.h",
507
"torch/csrc/lazy/core/shape.h",
508
f"{output_dir}/{backend_key}NativeFunctions.h",
509
f"{output_dir}/LazyIr.h",
512
["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
513
if gen_forced_fallback_code
517
"helper_fns": get_ltc_helper_fns(),
518
"native_functions_include": "",
519
"namespace_prologue": ns_helper.prologue,
520
"namespace_epilogue": ns_helper.epilogue,
521
"native_function_definitions": list(
523
native_func_definition_generator(
524
f"{backend_key}NativeFunctions",
525
backend_indices[backend_key],
527
gen_forced_fallback_code,
530
get_tensor_or_wrap_number,
534
create_from_first_tensor,
535
create_aten_from_ltc_tensor,
536
tuple_aten_from_ltc_tensors,
540
grouped_native_functions,
546
lazy_ir_obj = lazy_ir_generator(
547
backend_indices[backend_key], backend_name, node_base, use_lazy_shape
550
fm.write_with_template(
557
"ATen/core/Formatting.h",
558
"c10/core/ScalarType.h",
559
"c10/util/Optional.h",
560
"torch/csrc/lazy/core/hash.h",
561
"torch/csrc/lazy/core/ir.h",
562
"torch/csrc/lazy/core/shape.h",
566
"lazy_ir_inc": [f'#include "{node_base_hdr}"']
567
if node_base_hdr is not None
569
"ir_declarations": list(
571
lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
574
"namespace_prologue": ns_helper.prologue,
575
"namespace_epilogue": ns_helper.epilogue,
580
fm.write_with_template(
584
"lazy_non_native_ir_inc": [
587
"torch/csrc/lazy/core/ir.h",
588
"torch/csrc/lazy/core/ir_builder.h",
589
"torch/csrc/lazy/core/internal_ops/ltc_ops.h",
590
"torch/csrc/lazy/core/shape_inference.h",
592
+ ([node_base_hdr] if node_base_hdr else [])
595
"non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
596
non_native, lazy_ir_obj
598
"namespace_prologue": ns_helper.prologue,
599
"namespace_epilogue": ns_helper.epilogue,
604
if __name__ == "__main__":