pytorch

Форк
0
330 строк · 12.2 Кб
1
#!/bin/env python3
2

3
# Copyright (c) 2016-present, Facebook, Inc.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
##############################################################################
17

18
import sys
19
import yaml
20
import argparse
21
import os
22
from copy import deepcopy
23
from typing import Dict, List, Set
24

25
# Command-line options for the generator. Unrecognized options are ignored
# (parse_known_args) rather than rejected.
parser = argparse.ArgumentParser()
parser.add_argument("--template_dir", default=".",
                    help="directory containing aten_op_template.h")
parser.add_argument("--yaml_dir", default="aten/src/ATen/ATen",
                    help="where ATen yaml files (Declarations.yaml) are")
parser.add_argument("--output_prefix", default="",
                    help="prefix prepended to the generated aten_op.h file name")
parser.add_argument(
    "--install_dir", default=".", help="where to put generated file")
parser.add_argument("--aten_root", default="",
                    help="root directory of aten; its parent directory is "
                         "prepended to sys.path before importing torchgen")
args, _ = parser.parse_known_args()
34

35
# When --aten_root is given, validate it and put its parent directory at the
# front of sys.path before importing torchgen; otherwise import torchgen from
# the existing sys.path. The import itself is the same in both cases.
if args.aten_root:
    if not os.path.exists(args.aten_root):
        raise ValueError('aten_root ({}) does not exist'.format(
            args.aten_root))
    sys.path.insert(0, os.path.join(args.aten_root, '..'))
from torchgen.code_template import CodeTemplate as CT
43

44
# Header skeleton into which the collected 'mappings', 'implementations' and
# 'cases' snippets are substituted when the output file is written.
OP_TEMPLATE = CT.from_file(os.path.join(args.template_dir,
                                        'aten_op_template.h'))
46

47

48
# Prefer the faster C-backed CSafeLoader; fall back to the pure-Python
# SafeLoader when the C loader is not available in this PyYAML install.
Loader = getattr(yaml, 'CSafeLoader', yaml.SafeLoader)
53

54

55
def write(filename, s):
    """Write string ``s`` to ``filename``, replacing any existing content.

    The file is always encoded as UTF-8 so the generated header does not
    depend on the build machine's locale.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write(s)
58

59

60
def read(filename):
    """Return the full text of ``filename``, decoded as UTF-8.

    An explicit encoding keeps reading Declarations.yaml independent of the
    build machine's locale.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return f.read()
63

64

65
def value_has_tensors(v):
    """Return True if argument/return ``v`` involves tensors.

    Decided from ``v['dynamic_type']`` alone: it must mention "Tensor" and
    must not mention "Sparse".
    """
    dynamic_type = v['dynamic_type']
    # Sparse shouldn't appear in the public API; this seems to be working
    # around a temporary bug.
    if "Sparse" in dynamic_type:
        return False
    return "Tensor" in dynamic_type
68

69

70
def value_is_tensor_type(v):
    """Return True if ``v`` carries tensors but is not a tensor-list type.

    NOTE(review): ``v['dynamic_type']`` is compared against TENSORLIST_TYPE,
    whose entries look like C++ ``type`` spellings -- confirm this is intended.
    """
    if not value_has_tensors(v):
        return False
    return v['dynamic_type'] not in TENSORLIST_TYPE
72

73
# C++ types that carry a variable number of tensors. An op with an argument of
# one of these types reports '*' as its input count and reads the list as a
# slice of its inputs.
TENSORLIST_TYPE = [
    'at::TensorList',                                # plain tensor list
    'const at::ITensorListRef &',                    # tensor list reference
    'const c10::List<c10::optional<at::Tensor>> &',  # list of optional tensors
]
78

79
# for each aten type, how do we handle a return value of that type?
80
# Return-value handling, keyed by aten type. Each value is a CodeTemplate
# snippet; ${offset} is the Caffe2 output index and ${output} is the C++
# expression holding the returned value.
RETURN_MAP = {
    # single tensor
    'at::Tensor': 'assignTo(Output(${offset}),${output});',
    # scalar, passed together with its scalar type
    'at::Scalar': 'assignTo(Output(${offset}),${output}.type(), ${output});',
    # integral values are stored as int64_t
    'bool': 'assignToValue<int64_t>(Output(${offset}),${output});',
    'int64_t': 'assignToValue<int64_t>(Output(${offset}),${output});',
    # a vector of tensors fills consecutive outputs starting at ${offset}
    '::std::vector<at::Tensor>': 'assignListStartingAt(${offset}, ${output});',
}
87

88
# for each non-Tensor aten argument, how do we read it from caffe2's
89
# attribute list. Most of these call runtime functions defined in the
90
# template class.
91
# Keyed by aten argument type; each value is a CodeTemplate snippet declaring
# a C++ local named ${arg} and initializing it from the Caffe2 attribute of
# the same name.
ARGUMENT_MAP = {
    'const at::Scalar &': 'at::Scalar ${arg} = readScalarAttribute("${arg}");',
    # bool and int are read through the int64_t attribute reader
    'bool': 'bool ${arg} = readAttribute<int64_t>("${arg}");',
    'int': 'int ${arg} = readAttribute<int64_t>("${arg}");',
    # double is read through the float attribute reader
    'double': 'double ${arg} = readAttribute<float>("${arg}");',
    'int64_t': 'int64_t ${arg} = readAttribute<int64_t>("${arg}");',
    'at::IntArrayRef': 'auto ${arg} = readIntArrayRef("${arg}");',
    # fixed-size boolean masks
    '::std::array<bool,2>': 'auto ${arg} = readBoolMask<2>("${arg}");',
    '::std::array<bool,3>': 'auto ${arg} = readBoolMask<3>("${arg}");',
}
101

102
# for BC reasons we want to route some of the functions to different
103
# implementations
104
# op name -> C++ function invoked instead of the default at::/method call
SPECIAL_IMPLEMENTATIONS = {
    'index': 'internal::index_with_uint8_handling',
}
107

108
def expand(o):
    """Return ``o`` followed by one variant per droppable trailing default.

    The first element is ``o`` itself. The k-th following element is a deep
    copy of ``o`` with its last k arguments removed, for k = 1 .. the number
    of arguments carrying a 'default'. The defaulted arguments are asserted
    to be the trailing ones.
    """
    arguments = o['arguments']
    default_count = len([arg for arg in arguments if 'default' in arg])
    variants = [o]
    for dropped in range(1, default_count + 1):
        # only trailing arguments may carry defaults
        assert 'default' in arguments[-dropped]
        variant = deepcopy(o)
        variant['arguments'] = variant['arguments'][:-dropped]
        variants.append(variant)
    return variants
118

119

120
# filter the list of declarations removing things we cannot support
121
def supports(o, factory_methods):
    """Return True if declaration ``o`` can be turned into a Caffe2 ATen op.

    Args:
        o: one (possibly default-expanded) declaration from Declarations.yaml.
        factory_methods: maps factory-method names to a reject counter. It is
            mutated in place so the "factory method" message is printed only
            once per family.

    Returns:
        False for factory methods, in-place ops, names containing "_out",
        ops without returns, and ops with an argument or return value that
        the code generator cannot emit; True otherwise. A message is printed
        for factory methods and for unsupported arguments/returns.
    """
    # Ignore all families (!) of functions that have TensorOptions (i.e. tensor factory methods).
    if o['name'] in factory_methods:
        if factory_methods[o['name']] == 0:
            print("Skipping {} because it is a factory method".format(o['name']))
        factory_methods[o['name']] += 1
        return False

    # skip all in-place operators for now since aten cannot Resize
    # caffe2 memory inside an operator
    if o['inplace']:
        return False

    # _out variants also work in-place on arguments taken as destinations
    # we also cannot handle these because aten cannot resize caffe2 Tensors
    # NOTE(review): this is a substring test, so any name containing "_out"
    # is rejected, not only names ending in "_out".
    if "_out" in o['name']:
        return False

    # skip if there is no return value (such functions used to return 'void')
    if not o['returns']:
        return False

    # Skip return types we cannot handle. This mirrors emit_assignments():
    # a single tensor uses RETURN_MAP['at::Tensor'], anything else (including
    # tensor lists) must have its exact 'type' in RETURN_MAP.
    for ret in o['returns']:
        if not value_is_tensor_type(ret) and ret['type'] not in RETURN_MAP:
            print("Skipping {} Because of Ret: {} ({})".format(
                  o['name'], ret['type'], ret['dynamic_type']))
            return False

    # Skip arguments we cannot handle. This mirrors the argument loop of the
    # generator: tensor-list types are sliced from the inputs, single tensors
    # are peeked, and everything else is read through ARGUMENT_MAP.
    for arg in o['arguments']:
        handled = (value_is_tensor_type(arg)
                   or (value_has_tensors(arg) and arg['type'] in TENSORLIST_TYPE)
                   or arg['type'] in ARGUMENT_MAP)
        if not handled:
            print("Skipping {} Because of Arg: {} ({}) ".format(
                  o['name'], arg['type'], arg['dynamic_type']))
            return False
    return True
157

158

159
# template for each potential operator.
160
# each operator has an integer 'key' associated with it, and
161
# a lambda that defines the operator
162
# non-tensor attributes are created in ${initialization}
163
# and then saved as arguments to the lambda
164
# Inputs/Outputs are read inside the lambda
165
#
166
# each implementation is defined in a separate method annotated with
167
# C10_NOINLINE to avoid inlining into the ATenOp constructor, which would
168
# trigger pathological compile times.
169
# Every line of this template is emitted verbatim into the generated header.
# Placeholders are filled from the per-op env: key, name, initialization,
# statements, invocation, assignments.
IMPLEMENTATION_TEMPLATE = CT("""\
C10_NOINLINE void implementation_${key}() { // ${name}
    ${initialization}
    run_op = [=] {
        at::AutoDispatchBelowAutograd guard;
        ${statements}
        auto the_result = ${invocation};
        ${assignments}
        return true;
    };
}
""")
181

182
# One switch case per op key, dispatching to its implementation_${key}().
CASE_TEMPLATE = CT("""\
case ${key}: // ${name}
  implementation_${key}();
  break;
""")
187

188
# Wraps one output assignment so it only runs when the Caffe2 op has more
# than ${offset} outputs.
ASSIGN_CHECK_SIZE_TEMPLATE = CT("""\
  if(OutputSize() > ${offset}) {${assignment}}
""")
191

192

193
def get_output(o, i):
    """Return the C++ expression for the ``i``-th return value of ``o``.

    With exactly one return the result variable is used directly; otherwise
    element ``i`` is extracted with ``::std::get<i>``.
    """
    if len(o['returns']) != 1:
        return '::std::get<{}>(the_result)'.format(i)
    return 'the_result'
198

199

200
def attribute_names(o):
    """Return the alphabetically sorted names of ``o``'s non-tensor arguments."""
    names = [arg['name'] for arg in o['arguments'] if not value_has_tensors(arg)]
    names.sort()
    return names
202

203

204
def required_attribute_names(o):
    """Return the sorted names of non-tensor arguments of ``o`` without a default."""
    names = [arg['name'] for arg in o['arguments']
             if not (value_has_tensors(arg) or 'default' in arg)]
    names.sort()
    return names
206

207

208
def self_as_first_argument(arguments):
    """Return a new list with every argument named 'self' moved to the front.

    The relative order within the 'self' group and within the remaining
    arguments is preserved.
    """
    leading = [arg for arg in arguments if arg['name'] == 'self']
    trailing = [arg for arg in arguments if arg['name'] != 'self']
    return leading + trailing
211

212

213
def get_num_inputs(o):
    """Return the number of tensor inputs of ``o`` as a string.

    Returns '*' as soon as an argument of a tensor-list type is seen, because
    the input count is then variable; otherwise the count of tensor-carrying
    arguments.
    """
    tensor_args = 0
    for arg in o['arguments']:
        if arg['type'] in TENSORLIST_TYPE:
            return '*'
        tensor_args += 1 if value_has_tensors(arg) else 0
    return '{}'.format(tensor_args)
221

222

223
def find_factory_methods(decls):
    """Map each declaration name with an at::TensorOptions argument to 0.

    The 0 is the initial reject counter that supports() increments.
    """
    return {
        decl['name']: 0
        for decl in decls
        if any(a['dynamic_type'] == 'at::TensorOptions' for a in decl['arguments'])
    }
229

230

231
def emit_assignments(o, env):
    """Append one guarded output-assignment snippet per return value of ``o``.

    A single tensor uses the RETURN_MAP['at::Tensor'] snippet; any other
    return is looked up by its exact 'type'. The snippet receives the output
    index and C++ expression and is wrapped in ASSIGN_CHECK_SIZE_TEMPLATE.
    Results are appended to ``env['assignments']`` in return order.
    """
    for offset, ret in enumerate(o['returns']):
        map_key = 'at::Tensor' if value_is_tensor_type(ret) else ret['type']
        assignment = CT(RETURN_MAP[map_key]).substitute(
            env, offset=offset, output=get_output(o, offset))
        env['assignments'].append(ASSIGN_CHECK_SIZE_TEMPLATE.substitute(
            env, offset=offset, assignment=assignment))
238

239

240
if __name__ == '__main__':
    # Parse the ATen declarations. ``Loader`` is CSafeLoader or SafeLoader,
    # i.e. a safe loader.
    decls = yaml.load(read(os.path.join(args.yaml_dir, 'Declarations.yaml')), Loader=Loader)
    factory_methods = find_factory_methods(decls)
    # Expand each declaration into its default-dropping variants and keep only
    # the variants this generator can emit.
    filtered = [expanded for o in decls for expanded in expand(o) if supports(expanded, factory_methods)]
    # Lists of C++ snippets substituted into OP_TEMPLATE at the end.
    top_env: Dict[str, List] = {
        'mappings': [],
        'implementations': [],
        'cases': [],
    }
    seen: Set[str] = set()
    key = 0
    for o in filtered:
        # [DESCRIPTORS]
        # each option is associated with a descriptor string that is used
        # to figure out which version of an op is being used.
        # The format is:
        #     opname-attribute_1-attribute_2-...-num_inputs
        # where the attribute names are sorted alphabetically and num_inputs
        # is '*' when the op takes a tensor list.
        # Example:
        #  lerp-weight-2
        #  the operator lerp takes 2 tensor inputs and has the attribute weight
        attr_names = attribute_names(o)
        num_inputs = get_num_inputs(o)
        descriptor = '-'.join([o['name']] + attr_names + [num_inputs])
        # Several variants can share a descriptor; only the first is emitted.
        if descriptor in seen:
            continue
        seen.add(descriptor)

        # map from descriptor string to the integer key in the switch statements
        # that initializes the operators
        top_env['mappings'].append('{{ "{}", {} }},'.format(descriptor, key))
        env = {
            'name': o['name'],
            'statements': [],
            'arguments': [],
            'assignments': [],
            'initialization': [],
            'key': str(key),
        }

        if 'namespace' not in o['method_of'] and 'Tensor' not in o['method_of']:
            # methods on type like 'ones' or 'zeros' always take a
            # string attribute that is translated into the at::Type object
            # e.g. "Float" is at::kFloat
            assert('Type' in o['method_of'])

        # Number of single-tensor (non-list) inputs.
        static_tensor_inputs = sum(arg['type'] not in TENSORLIST_TYPE and value_is_tensor_type(arg) for arg in o['arguments'])
        has_tensorlist = any(arg['type'] in TENSORLIST_TYPE for arg in o['arguments'])
        if has_tensorlist:
            # Position of the first tensor-list argument.
            tensorlist_idx = next(i for i, arg in enumerate(o['arguments']) if arg['type'] in TENSORLIST_TYPE)

        real_inputs = 0
        for i, arg in enumerate(o['arguments']):
            env['arguments'].append(arg['name'])
            # Pretend the flat argument list is a stack where the end is the top.
            view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
            if arg['type'] == 'at::TensorList' or arg['type'] == 'const at::ITensorListRef &':
                # NOTE: do not advance real_inputs here. After this we will
                # switch to indexing the "stack" from the end
                env['statements'].append(
                    'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
                    .format(arg['name'], real_inputs, static_tensor_inputs))
            elif arg['type'] == 'const c10::List<c10::optional<at::Tensor>> &':
                # NOTE: do not advance real_inputs here. After this we will
                # switch to indexing the "stack" from the end
                env['statements'].append(
                    'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
                    .format(arg['name'], real_inputs, static_tensor_inputs))
            elif value_is_tensor_type(arg):
                # load tensor inputs from Caffe2
                env['statements'].append(
                    'auto {} = peek({}, {});'.format(arg['name'], real_inputs, view_length))
                real_inputs += 1
            else:
                # non-tensor argument: read from the Caffe2 attribute of the same name
                init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
                env['initialization'].append(init)

        emit_assignments(o, env)

        if o['name'] in SPECIAL_IMPLEMENTATIONS:
            env['invocation'] = "{}({})".format(SPECIAL_IMPLEMENTATIONS[o['name']], ','.join(env['arguments']))
        elif 'namespace' in o['method_of']:
            env['invocation'] = CT("at::${name}(${arguments})").substitute(env)
        else:
            assert('Tensor' in o['method_of'])
            # NOTE(review): assumes 'self' is the first argument; the
            # self_as_first_argument() helper is not applied here.
            env['invocation'] = "self.{}({})".format(
                o['name'], ', '.join(env['arguments'][1:]))

        top_env['implementations'].append(IMPLEMENTATION_TEMPLATE.substitute(env))
        top_env['cases'].append(CASE_TEMPLATE.substitute(env))
        key += 1
    write(os.path.join(args.install_dir, args.output_prefix + "aten_op.h"), OP_TEMPLATE.substitute(top_env))
331

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.