pytorch

Форк
0
467 строк · 16.1 Кб
1
from typing import List, Optional, Sequence, Set, Union
2

3
from torchgen import local
4
from torchgen.api.types import (
5
    ArgName,
6
    ArrayCType,
7
    ArrayRefCType,
8
    BaseCType,
9
    BaseTypeToCppMapping,
10
    Binding,
11
    boolT,
12
    ConstRefCType,
13
    CType,
14
    dimnameListT,
15
    intArrayRefT,
16
    iTensorListRefT,
17
    ListCType,
18
    longT,
19
    MutRefCType,
20
    NamedCType,
21
    OptionalCType,
22
    optionalIntArrayRefT,
23
    optionalSymIntArrayRefT,
24
    scalarT,
25
    SpecialArgName,
26
    symIntArrayRefT,
27
    SymIntT,
28
    tensorListT,
29
    tensorOptionsT,
30
    tensorT,
31
    TupleCType,
32
    VectorCType,
33
    voidT,
34
)
35
from torchgen.model import (
36
    Argument,
37
    Arguments,
38
    BaseTy,
39
    BaseType,
40
    FunctionSchema,
41
    ListType,
42
    NativeFunction,
43
    OptionalType,
44
    Return,
45
    SelfArgument,
46
    TensorOptionsArguments,
47
    Type,
48
)
49
from torchgen.utils import assert_never
50

51
# This file describes the translation of JIT schema to the public C++
52
# API, which is what people use when they call functions like at::add.
53
#
54
# Prominent characteristics of the C++ API:
55
#
56
#   - dtype, layout, device and pin_memory are collected into
57
#     a single C++ type TensorOptions  (the native functions API
58
#     also has this, but tensor options is really most relevant
59
#     for the C++ API; it makes calling kwarg factory functions
60
#     pleasant)
61
#
62
#   - defaulting lives here (in fact, the dispatcher is completely
63
#     oblivious of defaults!)
64
#
65
# BTW: policy on name collisions: we try not to have types with
66
# collisions, but functions are fair game to collide
67

68

69
def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    """Return the public C++ API name for ``func``.

    The schema's base name is suffixed, in this order, with:
      * ``_symint`` when ``symint_overload`` is set;
      * ``_outf`` (faithful) or ``_out`` (non-faithful) when ``func`` is an
        out= overload.
    """
    pieces = [str(func.name.name)]
    if symint_overload:
        pieces.append("_symint")
    if func.is_out_fn():
        pieces.append("_outf" if faithful_name_for_out_overloads else "_out")
    return "".join(pieces)
85

86

87
# Translation of "value types" in JIT schema to C++ API type.  Value
# types look the same no matter if they are argument types or return
# types.  Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> Optional[NamedCType]:
    """Translate a JIT value type into a C++ type bound to ``binds``.

    Args:
        t: the schema type to translate.
        binds: the name the resulting NamedCType binds to.
        remove_non_owning_ref_types: when set, a ``str`` base type is rejected
            rather than mapped.
        symint: map ``SymInt`` to SymIntT when True, to longT otherwise.

    Returns:
        The NamedCType, or None when ``t`` is not a value type (Tensor,
        Scalar, optionals of non-value types, and lists other than ``bool[N]``).

    Raises:
        AssertionError: for ``str`` with ``remove_non_owning_ref_types`` set,
            for an unsized ``bool[]``, and for unrecognized Type kinds.
    """
    if isinstance(t, BaseType):
        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
            return None
        if str(t) == "SymInt":
            int_ctype = SymIntT if symint else longT
            return NamedCType(binds, BaseCType(int_ctype))
        if remove_non_owning_ref_types and t.name == BaseTy.str:
            raise AssertionError("string ref->value conversion: not implemented yet")
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        # NOTE(review): remove_non_owning_ref_types is not forwarded here, so
        # e.g. `str?` is not rejected even when the flag is set -- confirm
        # whether that is intentional.
        inner = valuetype_type(t.elem, binds=binds, symint=symint)
        return None if inner is None else NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # Only fixed-size bool lists (e.g. bool[3]) are value types.
        if str(t.elem) != "bool":
            return None
        assert t.size is not None
        return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))

    raise AssertionError(f"unrecognized type {repr(t)}")
125

126

127
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we try to make the outputted CType
# an owning type: for example, std::vector<int64_t> instead of IntArrayRef.
# NOTE(review): only the int[] / SymInt[] branches (and the value-type path)
# consult the flag; optional int lists, Tensor lists, etc. still produce
# non-owning types -- confirm callers never pass those with the flag set.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate the schema type of an argument into its C++ API type.

    Args:
        t: the schema type of the argument.
        mutable: whether the argument is written to; mutable Tensors are
            passed by non-const reference unless the codegen-local setting
            ``use_const_ref_for_mutable_tensors`` is on.
        binds: the name the resulting NamedCType binds to.
        remove_non_owning_ref_types: see the comment above.
        symint: keep SymInt types as SymInt instead of lowering them to int64.

    Raises:
        AssertionError: for base types that are neither value types nor
            Tensor/Scalar, and for unrecognized Type kinds.
    """
    # Value types need no reference wrapping.
    as_value = valuetype_type(
        t,
        binds=binds,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if as_value is not None:
        return as_value

    def wants_mutable_ref() -> bool:
        # The codegen-local setting is only consulted for written-to arguments.
        return mutable and not local.use_const_ref_for_mutable_tensors()

    if isinstance(t, BaseType):
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        if t.name != BaseTy.Tensor:
            raise AssertionError(f"base type should have been value type {t}")
        if wants_mutable_ref():
            return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))

    if isinstance(t, OptionalType):
        inner = t.elem
        inner_str = str(inner)
        if inner_str == "Tensor":
            if wants_mutable_ref():
                # TODO: fix this discrepancy -- a mutable Tensor? is passed as a
                # plain Tensor& here, dropping the optional wrapper.
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
            )
        if inner_str == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        if isinstance(inner, ListType):
            list_elem = str(inner.elem)
            if list_elem == "SymInt" and symint:
                return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
            if list_elem in ("int", "SymInt"):
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        wrapped = argumenttype_type(inner, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(wrapped.type))

    if isinstance(t, ListType):
        elem_str = str(t.elem)
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if elem_str in ("int", "SymInt"):
            # int[] always lowers to int64; SymInt[] stays symbolic only if asked.
            use_symint = symint and elem_str == "SymInt"
            if remove_non_owning_ref_types:
                owned = SymIntT if use_symint else longT
                return NamedCType(binds, VectorCType(BaseCType(owned)))
            view = symIntArrayRefT if use_symint else intArrayRefT
            return NamedCType(binds, BaseCType(view))
        if elem_str == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            return NamedCType(binds, BaseCType(tensorListT))
        if elem_str == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        if elem_str == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        if elem_str == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        wrapped = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(wrapped.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
215

216

217
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    """Return the C++ API type of argument ``a``, bound to ``binds``.

    Mutability is taken from ``a.is_write``.
    """
    return argumenttype_type(
        a.type,
        binds=binds,
        mutable=a.is_write,
        symint=symint,
    )
220

221

222
# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    """Translate a single (non-tuple) schema return type into a C++ type.

    Args:
        t: the schema type of the return.
        mutable: whether the return is written to; a mutable Tensor is
            returned by reference, a non-mutable one by value.
        symint: accepted for signature symmetry but IGNORED; SymInt is always
            kept symbolic for return types.

    Raises:
        AssertionError: for mutable list returns, fixed-size list returns, and
            any type not covered below.
    """
    # NB: symint is ALWAYS respected for return types, so the symint argument is
    # IGNORED here. The binding name is a throwaway placeholder.
    as_value = valuetype_type(t, binds="__placeholder__", symint=True)
    if as_value is not None:
        return as_value.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
        if t.name == BaseTy.Tensor:
            if not mutable:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
            if local.use_const_ref_for_mutable_tensors():
                return ConstRefCType(BaseCType(tensorT))
            return MutRefCType(BaseCType(tensorT))
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem_type = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem_type)
    elif isinstance(t, OptionalType):
        # The element is translated first (that call may itself raise); only
        # Tensor? is accepted, other optionals fall through to the error below.
        elem_type = returntype_type(t.elem, mutable=mutable)
        if str(t.elem) == "Tensor":
            return OptionalCType(elem_type)

    raise AssertionError(f"unrecognized return type {t}")
263

264

265
# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
    """Return the C++ type of a single schema return ``r``.

    Mutability is taken from ``r.is_write``.
    """
    is_mutable = r.is_write
    return returntype_type(r.type, mutable=is_mutable, symint=symint)
268

269

270
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    """Return the C++ type for the full return list ``rs``.

    No returns map to ``void``, a single return to its own type, and several
    returns to a ``std::tuple`` of their types.
    """
    if not rs:
        return BaseCType(voidT)
    translated = [return_type(r, symint=symint) for r in rs]
    return translated[0] if len(translated) == 1 else TupleCType(translated)
278

279

280
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Compute the C++ variable name for each return of ``f``, in order.

    Rules, applied per return:
      * in-place functions: the (single) return is implicitly named ``self``;
      * out= functions: each return takes the name of the corresponding out
        argument (r.name is recorded separately as a field name);
      * explicitly named returns keep their name, with a ``_return`` suffix
        when it collides with one of the function's argument names;
      * otherwise ``fallback_name`` for a single return, or ``fallback_name``
        followed by the zero-based index for multiple returns.

    Args:
        f: the native function whose returns are being named.
        fallback_name: base name used for returns without an explicit name.

    Returns:
        A list with one name per return.

    Raises:
        AssertionError: if an in-place function has more than one return.
    """
    func = f.func
    # These facts do not change across returns; evaluate them once.
    # TODO: Consider incorporating the in-place "self" rule into the data model.
    is_inplace = func.name.name.inplace
    is_out = func.is_out_fn()
    single_return = len(func.returns) == 1
    # Collect argument names once so each collision check is a set lookup
    # instead of a fresh walk over the schema for every named return.
    arg_names = {a.name for a in func.schema_order_arguments()}

    returns: List[str] = []
    for i, r in enumerate(func.returns):
        if is_inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            ret_name = "self"
        elif is_out:
            ret_name = func.arguments.out[i].name
        elif r.name:
            # Out functions were handled above, so any clash here is with an
            # ordinary argument and must be disambiguated.
            ret_name = f"{r.name}_return" if r.name in arg_names else r.name
        else:
            ret_name = fallback_name if single_return else f"{fallback_name}{i}"
        returns.append(ret_name)
    return returns
310

311

312
# Maps JIT schema default spellings to the C++ expressions default_expr emits
# for them; spellings not listed here are passed through unchanged.
JIT_TO_CPP_DEFAULT = dict(
    [
        ("False", "false"),
        ("True", "true"),
        # UGH this one is type directed: default_expr special-cases a None
        # default for Tensor? to "{}" before consulting this table.
        ("None", "c10::nullopt"),
        ("Mean", "at::Reduction::Mean"),
        ("[]", "{}"),
        ("contiguous_format", "MemoryFormat::Contiguous"),
        ("long", "at::kLong"),
    ]
)
321

322

323
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    """Translate the schema default ``d`` of an argument typed ``t`` into C++.

    Args:
        d: the default exactly as spelled in the schema (e.g. ``None``,
            ``[]``, ``'mean'``, ``1``).
        t: the schema type of the argument the default belongs to.
        symint: allow a digit-only default of a SymInt list to be emitted as
            ``c10::SymInt(...)``.

    Returns:
        A C++ expression. Spellings found in JIT_TO_CPP_DEFAULT are mapped
        through it; anything else not handled earlier is returned verbatim.

    Raises:
        ValueError: if an unsized list type has a default that is neither a
            ``[...]`` literal nor the digit-only SymInt case.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"

    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            pieces: List[str] = []
            pos, end = 1, len(d) - 1
            while pos < end:
                ch = d[pos]
                if ch == "\\":
                    # Escapes are copied verbatim, except \' which needs no
                    # escaping inside a C++ double-quoted literal.
                    pieces.append("'" if d[pos + 1] == "'" else d[pos : pos + 2])
                    pos += 2
                else:
                    # A bare double quote must be escaped in the C++ literal.
                    pieces.append('\\"' if ch == '"' else ch)
                    pos += 1
            return '"' + "".join(pieces) + '"'

    if isinstance(t, OptionalType):
        # None means an empty optional; anything else is translated against
        # the wrapped type.
        return "c10::nullopt" if d == "None" else default_expr(d, t.elem, symint=symint)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        if symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        if t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
364

365

366
# Convert an argument into its C++ API form
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    symint: bool = False,
    has_tensor_options: bool,
) -> List[Binding]:
    """Translate one schema argument (or argument group) into C++ bindings.

    Args:
        a: a plain Argument, the grouped TensorOptionsArguments
            (dtype/layout/device/pin_memory), or the SelfArgument wrapper.
        cpp_no_default_args: names of arguments that must not get a C++ default.
        method: when True, ``self`` produces no binding; the caller is
            responsible for installing the implicit ``this`` in context.
        faithful: when True, tensor options expand into four separate
            bindings; otherwise they collapse into one ``options`` binding.
        symint: forwarded to type and default translation.
        has_tensor_options: whether the enclosing function has tensor options;
            if so, a ``memory_format`` argument binds to
            SpecialArgName.possibly_redundant_memory_format.

    Returns:
        Zero, one, or four bindings depending on the argument kind.
    """

    def recurse(
        inner: Union[Argument, TensorOptionsArguments, SelfArgument]
    ) -> List[Binding]:
        # Re-enter with identical settings for the pieces of a grouped argument.
        return argument(
            inner,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            symint=symint,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        arg_default: Optional[str] = None
        if a.default is not None and a.name not in cpp_no_default_args:
            arg_default = default_expr(a.default, a.type, symint=symint)
        binding = Binding(
            nctype=argument_type(a, binds=binds, symint=symint),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [binding]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            expanded: List[Binding] = []
            for part in (a.dtype, a.layout, a.device, a.pin_memory):
                expanded += recurse(part)
            return expanded
        # Enforced by NativeFunction.__post_init__
        assert "options" not in cpp_no_default_args
        options_default: Optional[str] = None
        if all(x.default == "None" for x in a.all()):
            options_default = "{}"
        elif a.dtype.default == "long":
            options_default = "at::kLong"  # TODO: this is wrong
        options_binding = Binding(
            nctype=NamedCType("options", BaseCType(tensorOptionsT)),
            name="options",
            default=options_default,
            argument=a,
        )
        return [options_binding]

    if isinstance(a, SelfArgument):
        # Caller is responsible for installing implicit this in context!
        return [] if method else recurse(a.argument)

    assert_never(a)
439

440

441
def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: Set[str],
) -> List[Binding]:
    """Translate a function's full argument list into ordered C++ bindings.

    Faithful signatures keep schema order (non-out arguments, then out
    arguments) and strip every default via Binding.no_default(). Non-faithful
    signatures put the out arguments first and keep defaults.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = (
        [*arguments.non_out, *arguments.out]
        if faithful
        else [*arguments.out, *arguments.non_out]
    )
    has_options = arguments.tensor_options is not None

    bindings: List[Binding] = []
    for arg in ordered:
        for binding in argument(
            arg,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=has_options,
            cpp_no_default_args=cpp_no_default_args,
        ):
            bindings.append(binding.no_default() if faithful else binding)
    return bindings
468

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.