6
from typing import List, Optional
8
import torchgen.api.python as python
9
from torchgen.api import cpp
11
from torchgen.api.types import CppSignatureGroup
12
from torchgen.context import with_native_function
13
from torchgen.gen import parse_native_yaml
14
from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
15
from torchgen.utils import FileManager, mapMaybe
# Within fully_qualified_type both patterns are applied with ``Pattern.match``,
# i.e. anchored at the start of the C++ argument type string.
OPTIONAL_TYPE_PATTERN = re.compile(r"c10::optional<(.+)>")
# An optional leading ``const`` followed by a capitalized identifier; such an
# identifier is treated as a type that lives in the ``at`` namespace.
TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")


def fully_qualified_type(argument_type: str) -> str:
    """Prefix the leading ATen-style type name in a C++ argument type with ``at::``.

    A "type name" here is a capitalized identifier at the start of the string,
    optionally preceded by ``const`` (see ``TYPE_PATTERN``). If the whole type
    starts with ``c10::optional<...>``, the wrapped type is qualified instead
    and the wrapper is kept.

    Examples::

        "Tensor"                  -> "at::Tensor"
        "const Tensor &"          -> "const at::Tensor &"
        "c10::optional<Tensor>"   -> "c10::optional<at::Tensor>"
        "int64_t"                 -> "int64_t"   (no capitalized name: unchanged)

    Args:
        argument_type: C++ type string of a generated function argument.

    Returns:
        The type string with ``at::`` inserted before the leading type name,
        or the input unchanged when no such name is found.
    """
    prefix = ""
    suffix = ""
    opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
    if opt_match is not None:
        # Unwrap so the inner type can be qualified; the wrapper is re-applied
        # below. Text after the closing ``>`` (e.g. a trailing `` &``) is kept
        # instead of being silently dropped.
        prefix = "c10::optional<"
        suffix = ">" + argument_type[opt_match.end() :]
        argument_type = opt_match.group(1)

    match = TYPE_PATTERN.match(argument_type)
    if match is not None:
        # Insert the namespace right before the type name, after any ``const``.
        index = match.start(1)
        argument_type = f"{argument_type[:index]}at::{argument_type[index:]}"
    return f"{prefix}{argument_type}{suffix}"
def gen_variable_factories(
    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
) -> None:
    """Write ``variable_factories.h`` into *out* from the native function YAML.

    Every native function accepted by ``is_factory_function`` contributes a
    per-operator ``#include`` and the wrapper definitions produced by
    ``process_function``.

    Args:
        out: Output directory for the generated header.
        native_yaml_path: Path to ``native_functions.yaml``.
        tags_yaml_path: Path to the tags YAML passed to ``parse_native_yaml``.
        template_path: Directory containing the ``variable_factories.h`` template.
    """
    native_functions = parse_native_yaml(
        native_yaml_path, tags_yaml_path
    ).native_functions
    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
    # Several factory functions may share a root_name and therefore the same
    # per-operator header; emit each include once, keeping first-seen order so
    # the generated file stays deterministic.
    ops_headers = list(
        dict.fromkeys(
            f"#include <ATen/ops/{fn.root_name}.h>" for fn in factory_functions
        )
    )
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    fm.write_with_template(
        "variable_factories.h",
        "variable_factories.h",
        lambda: {
            "generated_comment": "@"
            + f"generated from {fm.template_dir_for_comments()}/variable_factories.h",
            "ops_headers": ops_headers,
            "function_definitions": list(mapMaybe(process_function, factory_functions)),
        },
    )
@with_native_function
def is_factory_function(f: NativeFunction) -> bool:
    """Return whether *f* should get a wrapper in ``variable_factories.h``.

    A native function qualifies when it has the ``Variant.function`` variant
    and either takes tensor options (per ``python.has_tensor_options``) or its
    C++ name ends with ``_like``.

    Args:
        f: The native function to classify.

    Returns:
        ``True`` if *f* is a factory function, ``False`` otherwise.
    """
    if Variant.function not in f.variants:
        return False

    # Use the C++ API name so the ``_like`` check matches the name of the
    # generated C++ function.
    name = cpp.name(f.func)
    return python.has_tensor_options(f) or name.endswith("_like")
@with_native_function
def process_function(f: NativeFunction) -> Optional[str]:
    """Generate the C++ wrapper definitions for a factory native function.

    For each C++ signature of *f*, one inline ``at::Tensor`` wrapper is emitted.
    This covers the regular signature and, when present, the SymInt one. Each
    wrapper calls the same-named ``at::`` function under an
    ``at::AutoDispatchBelowADInplaceOrView`` guard and wraps the result with
    ``autograd::make_variable``.

    Args:
        f: The native function to generate wrappers for.

    Returns:
        The concatenated C++ source of the wrappers, or ``None`` when *f* is
        not a factory function (see ``is_factory_function``).
    """
    if not is_factory_function(f):
        return None

    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
    sigs = [cpp_sigs.signature]
    if cpp_sigs.symint_signature is not None:
        sigs.append(cpp_sigs.symint_signature)

    definitions: List[str] = []
    for sig in sigs:
        formals: List[str] = []
        exprs: List[str] = []
        requires_grad = "false"
        for arg in sig.arguments():
            qualified_type = fully_qualified_type(arg.type)
            if arg.default:
                formals.append(f"{qualified_type} {arg.name} = {arg.default}")
            else:
                formals.append(f"{qualified_type} {arg.name}")

            if isinstance(arg.argument, TensorOptionsArguments):
                # Forward the options to the ATen call with requires_grad
                # cleared; requires_grad is applied to the result through
                # make_variable below instead.
                exprs.append(
                    f"at::TensorOptions({arg.name}).requires_grad(c10::nullopt)"
                )
                requires_grad = f"{arg.name}.requires_grad()"
            else:
                exprs.append(arg.name)

        definitions.append(
            f"""\
inline at::Tensor {sig.name()}({', '.join(formals)}) {{
  at::AutoDispatchBelowADInplaceOrView guard;
  return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad});
}}
"""
        )
    return "".join(definitions)