22
from copy import deepcopy
23
from typing import Dict, List, Set
25
parser = argparse.ArgumentParser()
26
parser.add_argument("--template_dir", default=".", help="where template.h is")
27
parser.add_argument("--yaml_dir", default="aten/src/ATen/ATen",
28
help="where ATen yaml files are")
29
parser.add_argument("--output_prefix", default="", help="")
31
"--install_dir", default=".", help="where to put generated file")
32
parser.add_argument("--aten_root", default="", help="root directory of aten")
33
args, _ = parser.parse_known_args()
36
if not os.path.exists(args.aten_root):
37
raise ValueError('aten_root ({}) does not exist'.format(
39
sys.path.insert(0, os.path.join(args.aten_root, '..'))
40
from torchgen.code_template import CodeTemplate as CT
42
from torchgen.code_template import CodeTemplate as CT
44
OP_TEMPLATE = CT.from_file(
45
os.path.join(args.template_dir, 'aten_op_template.h'))
50
from yaml import CSafeLoader as Loader
52
from yaml import SafeLoader as Loader
55
def write(filename, s):
56
with open(filename, "w") as f:
61
with open(filename, "r") as f:
65
def value_has_tensors(v):
    # True when the declaration value `v` is tensor-flavored: its
    # 'dynamic_type' string mentions "Tensor" but not "Sparse" (sparse
    # tensors are excluded from this generator's handling).
    # NOTE(review): one original interior line is missing from this chunk —
    # presumed to be a comment; confirm against the full file.
    return "Tensor" in v['dynamic_type'] and "Sparse" not in v['dynamic_type']
70
def value_is_tensor_type(v):
    """Whether `v` is a single tensor value: tensor-flavored per
    value_has_tensors, and not one of the tensor-list dynamic types."""
    if not value_has_tensors(v):
        return False
    return v['dynamic_type'] not in TENSORLIST_TYPE
75
'const at::ITensorListRef &',
76
'const c10::List<c10::optional<at::Tensor>> &',
81
'at::Tensor': 'assignTo(Output(${offset}),${output});',
82
'at::Scalar': 'assignTo(Output(${offset}),${output}.type(), ${output});',
83
'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
84
'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
85
'::std::vector<at::Tensor>': 'assignListStartingAt(${offset}, ${output});',
92
'const at::Scalar &': 'at::Scalar ${arg} = readScalarAttribute("${arg}");',
93
'bool': 'bool ${arg} = readAttribute<int64_t>("${arg}");',
94
'int': 'int ${arg} = readAttribute<int64_t>("${arg}");',
95
'double': 'double ${arg} = readAttribute<float>("${arg}");',
96
'int64_t': 'int64_t ${arg} = readAttribute<int64_t>("${arg}");',
97
'at::IntArrayRef': 'auto ${arg} = readIntArrayRef("${arg}");',
98
'::std::array<bool,2>': 'auto ${arg} = readBoolMask<2>("${arg}");',
99
'::std::array<bool,3>': 'auto ${arg} = readBoolMask<3>("${arg}");',
104
SPECIAL_IMPLEMENTATIONS = {
105
'index': 'internal::index_with_uint8_handling',
109
num_defaults = sum(1 if 'default' in arg else 0 for arg in o['arguments'])
111
for i in range(0, num_defaults):
113
assert('default' in o['arguments'][-(i + 1)])
115
v['arguments'] = v['arguments'][:-(i + 1)]
121
def supports(o, factory_methods):
123
if o['name'] in factory_methods:
124
if factory_methods[o['name']] == 0:
125
print("Skipping {} because it is a factory method".format(o['name']))
126
factory_methods[o['name']] += 1
136
if "_out" in o['name']:
140
if len(o['returns']) == 0:
144
for ret in o['returns']:
145
if not value_has_tensors(ret) and ret['type'] not in RETURN_MAP:
146
print("Skipping {} Because of Ret: {} ({})".format(
147
o['name'], ret['type'], ret['dynamic_type']))
151
for arg in o['arguments']:
152
if not value_has_tensors(arg) and arg['type'] not in ARGUMENT_MAP:
153
print("Skipping {} Because of Arg: {} ({}) ".format(
154
o['name'], arg['type'], arg['dynamic_type']))
169
IMPLEMENTATION_TEMPLATE = CT("""\
170
C10_NOINLINE void implementation_${key}() { // ${name}
173
at::AutoDispatchBelowAutograd guard;
175
auto the_result = ${invocation};
182
CASE_TEMPLATE = CT("""\
183
case ${key}: // ${name}
184
implementation_${key}();
188
ASSIGN_CHECK_SIZE_TEMPLATE = CT("""\
189
if(OutputSize() > ${offset}) {${assignment}}
194
if len(o['returns']) == 1:
197
return '::std::get<{}>(the_result)'.format(i)
200
def attribute_names(o):
    """Sorted names of the non-tensor (attribute) arguments of operator `o`."""
    names = [arg['name']
             for arg in o['arguments']
             if not value_has_tensors(arg)]
    names.sort()
    return names
204
def required_attribute_names(o):
    """Sorted names of operator `o`'s non-tensor arguments that carry no
    default value (i.e. the attributes a caller must always supply)."""
    required = []
    for arg in o['arguments']:
        if value_has_tensors(arg) or 'default' in arg:
            continue
        required.append(arg['name'])
    return sorted(required)
208
def self_as_first_argument(arguments):
    """Return `arguments` reordered so every entry named 'self' comes first.

    The relative order within the 'self' group and within the remaining
    entries is preserved (a stable partition).
    """
    selfs = []
    others = []
    for arg in arguments:
        if arg['name'] == 'self':
            selfs.append(arg)
        else:
            others.append(arg)
    return selfs + others
213
def get_num_inputs(o):
215
for a in o['arguments']:
216
if a['type'] in TENSORLIST_TYPE:
218
elif value_has_tensors(a):
223
def find_factory_methods(decls):
226
if any(arg['dynamic_type'] == 'at::TensorOptions' for arg in o['arguments']):
227
factory_methods[o['name']] = 0
228
return factory_methods
231
def emit_assignments(o, env):
232
for i, r in enumerate(o['returns']):
233
t = RETURN_MAP[r['type'] if not value_is_tensor_type(r) else 'at::Tensor']
234
assignment = CT(t).substitute(env, offset=i, output=get_output(o, i))
235
check_size_assignment = ASSIGN_CHECK_SIZE_TEMPLATE.substitute(env, offset=i, assignment=assignment)
237
env['assignments'].append(check_size_assignment)
240
if __name__ == '__main__':
241
decls = yaml.load(read(os.path.join(args.yaml_dir, 'Declarations.yaml')), Loader=Loader)
242
factory_methods = find_factory_methods(decls)
243
filtered = [expanded for o in decls for expanded in expand(o) if supports(expanded, factory_methods)]
244
top_env: Dict[str, List] = {
246
'implementations': [],
249
seen: Set[str] = set()
260
attr_names = attribute_names(o)
261
num_inputs = get_num_inputs(o)
262
descriptor = '-'.join([o['name']] + attr_names + [num_inputs])
263
if descriptor in seen:
269
top_env['mappings'].append('{{ "{}", {} }},'.format(descriptor, key))
275
'initialization': [],
279
if 'namespace' not in o['method_of'] and 'Tensor' not in o['method_of']:
283
assert('Type' in o['method_of'])
285
static_tensor_inputs = sum(arg['type'] not in TENSORLIST_TYPE and value_is_tensor_type(arg) for arg in o['arguments'])
286
has_tensorlist = any(arg['type'] in TENSORLIST_TYPE for arg in o['arguments'])
288
tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in TENSORLIST_TYPE][0]
291
for i, arg in enumerate(o['arguments']):
292
env['arguments'].append(arg['name'])
294
view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
295
if arg['type'] == 'at::TensorList' or arg['type'] == 'const at::ITensorListRef &':
298
env['statements'].append(
299
'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
300
.format(arg['name'], real_inputs, static_tensor_inputs))
301
elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
304
env['statements'].append(
305
'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
306
.format(arg['name'], real_inputs, static_tensor_inputs))
307
elif value_is_tensor_type(arg):
309
env['statements'].append(
310
'auto {} = peek({}, {});'.format(arg['name'], real_inputs, view_length))
313
init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
314
env['initialization'].append(init)
316
emit_assignments(o, env)
318
if o['name'] in SPECIAL_IMPLEMENTATIONS:
319
env['invocation'] = "{}({})".format(SPECIAL_IMPLEMENTATIONS[o['name']], ','.join(env['arguments']))
320
elif 'namespace' in o['method_of']:
321
env['invocation'] = CT("at::${name}(${arguments})").substitute(env)
323
assert('Tensor' in o['method_of'])
324
env['invocation'] = "self.{}({})".format(
325
o['name'], ', '.join(env['arguments'][1:]))
327
top_env['implementations'].append(IMPLEMENTATION_TEMPLATE.substitute(env))
328
top_env['cases'].append(CASE_TEMPLATE.substitute(env))
330
write(os.path.join(args.install_dir, args.output_prefix + "aten_op.h"), OP_TEMPLATE.substitute(top_env))