pytorch / lazy_ir.py (707 lines · 28.3 KB)
from __future__ import annotations

import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any

import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
    getValueT,
    isValueType,
    LazyArgument,
    LazyIrProperties,
    LazyIrSchema,
    tensorListValueT,
)
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    deviceT,
    DispatcherSignature,
    kernel_signature,
    NativeSignature,
    OptionalCType,
    VectorCType,
)
from torchgen.context import method_with_native_function
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
)


def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
    """
    Given a LazyArgument,
    generate a c++ string for materializing an rvalue of that arg for passing into
    a lazy Node constructor.
    """

    # TODO: Matching on CType seems wrong; should be matching on Type
    if isValueType(arg.lazy_type):
        if isinstance(arg.lazy_type, BaseCType):
            if arg.is_wrapped_scalar:
                return f"node_{arg.name}"
            elif arg.lazy_type.type is tensorListValueT:
                return f"lazy_{arg.name}_tensorlist"
            elif arg.is_symint_or_list:
                return f"GetSymIntValue({arg.name})"
            return f"lazy_{arg.name}->GetIrValue()"
        elif isinstance(arg.lazy_type, OptionalCType):
            if arg.is_symint_or_list:
                # TODO: I don't understand when you should put lazy_ in the name
                # or not
                return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
            elif arg.is_wrapped_scalar:
                return f"node_{arg.name}"
            return (
                f"lazy_{arg.name} ? "
                f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
                "::std::nullopt"
            )
        else:
            raise AssertionError(
                f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
            )
    else:
        # NB: this is here because right now we aren't treating SymInt[] as a
        # value type; when we do this needs to move above
        # NB: we cannot test arg.lazy_type as we've already specified it is an
        # int64_t and so we cannot distinguish between SymInt and int64_t
        if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
            BaseTy.SymInt
        ):
            if arg.symint:
                return f"GetSymIntArrayRefValue({arg.name})"
            else:
                return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
        elif isinstance(arg.lazy_type, VectorCType) and isinstance(
            arg.lazy_type.elem, BaseCType
        ):
            return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
        elif (
            isinstance(arg.lazy_type, OptionalCType)
            and isinstance(arg.lazy_type.elem, VectorCType)
            and isinstance(arg.lazy_type.elem.elem, BaseCType)
        ):
            return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
        else:
            return f"{arg.name}"


def node_ctor_inputs(schema: LazyIrSchema) -> str:
    """
    Produce a formatted string with the arguments as passed into the constructor of a node class.
    """
    node_ctor_values = [
        node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
    ]
    return ", ".join(node_ctor_values)


def gen_fallback_code(
    schema: LazyIrSchema,
    sig: DispatcherSignature | NativeSignature,
    overload_name: str,
) -> str:
    """
    Generate code that falls back to eager conditioned on a predicate
    """
    dispatcher_sig = DispatcherSignature.from_schema(schema.func)
    exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    fallback_args = ",\n                ".join([a.expr for a in exprs])
    if len(overload_name):
        aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
    else:
        aten_op_str = f"ATEN_OP({schema.aten_name})"
    return f"""
        if (force_eager_fallback({aten_symbol(schema)})) {{
            return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
                {fallback_args}
            );
        }}
"""


def aten_symbol(schema: LazyIrSchema) -> str:
    missing_interned_strings = {
        "sigmoid_backward",
    }
    if schema.aten_name in missing_interned_strings:
        return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'

    if not schema.aten_name.startswith("at::"):
        return f"at::aten::{schema.aten_name}"
    else:
        return schema.aten_name
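# For most names this returns the interned form, e.g. "mul" -> "at::aten::mul";
# the ops listed in missing_interned_strings (currently just sigmoid_backward)
# instead get 'c10::Symbol::fromQualString("aten::sigmoid_backward")'.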


# Converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = to_meta({arg.name});"
            )
            context.append(arg.with_name(unwrapped_name))
        else:
            context.append(arg)
    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context
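# Sketch of the output for a hypothetical signature add(Tensor self, Tensor other,
# Scalar alpha): the returned string would contain
#   auto self_meta = to_meta(self);
#   auto other_meta = to_meta(other);
# and the returned context rebinds self/other to self_meta/other_meta while
# passing alpha through unchanged.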


@dataclass(frozen=True)
class GenLazyIR(ABC):
    backend_index: BackendIndex
    backend_name: str
    node_base: str
    use_lazy_shape: bool

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        metadata = self.backend_index.get_kernel(
            f.functional if isinstance(f, NativeFunctionsGroup) else f
        )
        schema = LazyIrSchema(
            func, symint=metadata is not None and metadata.supports_symint()
        )
        return self.gen(schema)

    # there is no lowering functionality generated unless this IR base class is subclassed and
    # implemented as a backend-specific node
    def lowering_function(self, schema: LazyIrSchema) -> str:
        return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        return f"""bool CanBeReused({node_ctor_args}) const {{
    return false;
    }}"""

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        value_args = schema.filtered_args(values=True, scalars=False)
        # backends can customize the way the node base class constructor is called,
        # as long as all of its arguments can be generated from information available from the schema
        base_ctor_value_args_list = []
        for arg in value_args:
            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                base_ctor_value_args_list.append(f"{arg.name}")
            elif isinstance(arg.lazy_type, OptionalCType):
                base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
            else:
                raise AssertionError(
                    f"Unsupported type ({arg.lazy_type}) - add support if necessary"
                )
        base_ctor_value_args = ", ".join(base_ctor_value_args_list)

        scalar_args = schema.filtered_args(values=False, scalars=True)

        # Shape construction.
        # Conditionally build shape depending on specified shape property
        if schema.properties.ShapePrecompute:
            shape_ctor_arg = "std::move(shapes),"
        elif schema.properties.ShapeCompute:
            shape_args = [a.name for a in value_args]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
        elif schema.properties.ShapeCache:
            shape_args = [f"operand({i})" for i in range(len(value_args))]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
        else:
            shape_ctor_arg = ""

        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        return f"""{self.node_base}(
              {schema.node_name}::ClassOpKind(),
              OpList{{{base_ctor_value_args}}},
              {shape_ctor_arg}
              /* num_outputs */ {len(schema.returns)},
              torch::lazy::MHash({scalar_hashes}))"""
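    # Illustrative expansion of the template above, assuming node_base ==
    # "torch::lazy::TsNode", a node named AddTensor with one value operand `self`,
    # one scalar `alpha`, a single return, and the ShapePrecompute property
    # (all hypothetical values):
    #   torch::lazy::TsNode(
    #       AddTensor::ClassOpKind(),
    #       OpList{self},
    #       std::move(shapes),
    #       /* num_outputs */ 1,
    #       torch::lazy::MHash(alpha))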

    def gen(self, schema: LazyIrSchema) -> list[str]:
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        all_args = schema.filtered_args()
        scalar_args = schema.filtered_args(values=False, scalars=True)

        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        reuse_ctor_args = ", ".join(ctor_args)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n        ".join(
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if len(scalar_initializers):
            scalar_initializers = f",\n        {scalar_initializers}"
        scalar_decls = "\n  ".join(
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"::std::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
        )
        optional_values = [
            arg.name
            for arg in schema.filtered_args(values=True, scalars=False)
            if isinstance(arg.lazy_type, OptionalCType)
        ]
        has_optional_decls = "\n  ".join(
            [f"bool has_{value}: 1;" for value in optional_values]
        )
        has_optional_defs = "\n    ".join(
            [f"has_{value} = !!{value};" for value in optional_values]
        )
        members_to_string = []
        for arg in scalar_args:
            if isinstance(arg.lazy_type, OptionalCType):
                value = f"{arg.name}.value()"
                if arg.is_generator:
                    value = '"torch.Generator()"'
                members_to_string.append(
                    f"""if ({arg.name}.has_value()) {{
      ss << ", {arg.name}=" << {value};
    }} else {{
      ss << ", {arg.name}=null";
    }}"""
                )
            else:
                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
        members_to_string_str = "\n    ".join(members_to_string)

        return [
            f"""\
class {schema.node_name} : public {self.node_base} {{
 public:
  static torch::lazy::OpKind ClassOpKind() {{
    return torch::lazy::OpKind({opkind});
  }}

  {schema.node_name}({node_ctor_args})
      : {self.node_base_ctor_call(schema)}{scalar_initializers}
  {{
    {has_optional_defs}
  }}

  std::string ToString() const override {{
    std::stringstream ss;
    ss << {self.node_base}::ToString();
    {members_to_string_str}
    return ss.str();
  }}

  {self.create_function(schema, reuse_ctor_args)}

  {self.can_be_reused_function(schema, reuse_ctor_args)}

  {self.lowering_function(schema)}

  {scalar_decls}
  {has_optional_decls}

}};

""",
        ]


@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
    def lowering_function(self, schema: LazyIrSchema) -> str:
        signature = """
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override"""

        if schema.properties.LowerDeclOnly:
            return f"{signature};"
        elif schema.properties.Lower:
            return f"""{signature} {{
    {ts_lowering_body(schema)}
  }}
            """
        else:
            return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        signature = f"static NodePtr Create({node_ctor_args})"
        if schema.properties.CreateFnDeclOnly:
            return f"{signature};"
        elif not schema.properties.CreateFn:
            return ""
        return f"""{signature} {{
    return ReuseOrMakeNode<{schema.node_name}>(data);
  }}"""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        signature = f"bool CanBeReused({node_ctor_args}) const"
        if schema.properties.CanBeReusedDeclOnly:
            return f"{signature};"
        elif not schema.properties.CanBeReused:
            return ""
        value_comparison = []
        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
                )
            else:
                value_comparison.append(f"this->{arg.name} == {arg.name}")
        value_comparison_str = " &&\n        ".join(value_comparison)

        return f"""{signature} {{
    size_t i = 0;
    return ({value_comparison_str});
  }}"""


@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
    class_method_name: str
    backend_index: BackendIndex
    tensor_class: str
    gen_forced_fallback_code: bool
    backend_namespace: str
    get_tensorlist: str
    get_tensor_or_wrap_number: str
    try_get_tensor: str
    metrics_counter: str
    create_tensor: str
    create_from_first_tensor: bool
    create_aten_from_ltc_tensor: str
    tuple_aten_from_ltc_tensors: str
    lazy_tensor_ptr: str
    get_device_fn: str

    def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        value_args = schema.filtered_args(values=True, scalars=False)
        # Generates lazy_{name} variables for LazyTensors wrapping input tensors
        lazy_tensor_decls: list[str] = []
        for arg in value_args:
            if arg.is_wrapped_scalar:
                if isinstance(arg.lazy_type, OptionalCType):
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = {arg.name} ?
                std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
                    GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
                ::std::nullopt;"""
                    )
                else:
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
                            GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
                    )
            elif arg.is_symint_or_list:
                continue  # values are extracted in isValueType
            elif isinstance(arg.lazy_type, BaseCType):
                if arg.lazy_type.type is tensorListValueT:
                    lazy_tensor_decls.append(
                        f"auto lazy_{arg.name}_tensorlist = "
                        f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
                    )
                else:
                    lazy_tensor_decls.append(
                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                        f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                    )
            elif isinstance(arg.lazy_type, OptionalCType):
                assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
                # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                # until we encounter a real world example.
                lazy_tensor_decls.append(
                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                    f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                )
            else:
                raise AssertionError(
                    f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
                )
        return ("\n        ").join(lazy_tensor_decls)

    def force_eager_fallback(
        self,
        func: NativeFunction,
        schema: LazyIrSchema,
        metadata: BackendMetadata,
        sig: DispatcherSignature | NativeSignature,
    ) -> str:
        if self.gen_forced_fallback_code:
            return gen_fallback_code(
                schema, sig, overload_name=func.func.name.overload_name
            )
        return ""

    def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        return f"{self.metrics_counter};"

    def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        optional_device = OptionalCType(BaseCType(deviceT))
        optional_devices = [
            a.name for a in scalar_args if a.lazy_type == optional_device
        ]
        assert (
            len(value_types_names) > 0 or len(optional_devices) > 0
        ), "Expected at least one Value or Device type"
        get_device_str = (
            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
        )
        return f"""auto common_device = {get_device_str};
        TORCH_INTERNAL_ASSERT(common_device);
        """

    def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
            meta_call_args = [
                e.expr
                for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False
                )
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            aten_name = schema.aten_name
            # TODO: this is trolling
            if func.func.has_symint() and metadata.supports_symint():
                aten_name += "_symint"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            shape_sig = ComputeShapeSignature(
                metadata.kernel, func, symint=metadata.supports_symint()
            )
            shape_str = f"""
            auto shapes = {shape_sig.shape_call};"""

        shape_str += f"""
            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

        # Calculating which dimensions are symbolic
        func_schema_str = "aten::" + str(func.func)
        shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
        return shape_str

    def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        node_ctor_input_str = node_ctor_inputs(schema)
        return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
        if (!node) {{
            {self.shape_inference(func, schema)}
            node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
            CacheNode(node);
        }}
        """

    def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
        # xla uses an instance method for tensor creation, for the time being
        if self.create_from_first_tensor:
            # TODO(whc) remove this if XLA switches to using static method for creation
            assert (
                first_tensor_name is not None
            ), "Requires first tensor to create lazy tensor"
            return f"{first_tensor_name}.{self.create_tensor}"
        return f"{self.backend_namespace}::{self.create_tensor}"

    def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        returns_length = len(schema.returns)
        value_args = schema.filtered_args(values=True, scalars=False)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
        bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
                {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""

        if returns_length > 1:
            assert (
                len(value_types_names) > 0
            ), "Code below assumes there is at least one tensor arg"
            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
        }}
        auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""

        if schema.name.name.inplace or func.func.is_out_fn():
            assert returns_length == 1, (
                "We assumed there was no such case where an op is an in-place variant "
                f"and has tuple outputs, but got tuple of len {returns_length}."
            )
            bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

        bridge_str += """
        return result;"""
        return bridge_str

    @method_with_native_function
    def __call__(self, func: NativeFunction) -> list[str]:
        sig = kernel_signature(func, self.backend_index)
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
        return [
            f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {self.force_eager_fallback(func, schema, metadata, sig)}
        {self.metrics(func, schema)}
        {self.get_device(func, schema)}
        {self.lazy_tensor_decls(func, schema)}
        {self.build_ir_node(func, schema)}
        {self.return_aten_tensor(func, schema)}
    }}\n
    """
        ]


class ComputeShapeSignature:
    """
    Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
    """

    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
        self.__schema = LazyIrSchema(f.func, symint=symint)
        self.__dispatch_args = ", ".join(
            [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
        )
        self.__call_args = ", ".join(
            [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
        )
        self.__kernel_name = kernel_name

    def __decl_suffix(self) -> str:
        return f"{self.__kernel_name}({self.__dispatch_args})"

    def __call_suffix(self) -> str:
        return f"{self.__kernel_name}({self.__call_args})"

    @property
    def shape_decl(self) -> str:
        return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"

    @property
    def shape_call(self) -> str:
        return f"torch::lazy::compute_shape_{self.__call_suffix()}"


@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
    backend_index: BackendIndex
    tensor_class: str

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> list[str]:
        metadata = self.backend_index.get_kernel(f)
        assert metadata is not None

        # See Note [Generated LTC Shape Functions]
        is_view_copy_op = "view_copy" in f.tags
        is_structured = f.structured or f.structured_delegate is not None
        if is_structured or is_view_copy_op:
            return []
        else:
            shape_sig = ComputeShapeSignature(
                metadata.kernel, f, symint=metadata.supports_symint()
            )
            return ["\n".join([f"{shape_sig.shape_decl};"])]


def generate_non_native_lazy_ir_nodes(
    non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
    """Generate the non-native lazy IR node classes"""
    nodes = []
    for op in non_native:
        # Set default properties for Non-Native IRs
        properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for p in op.get("properties", []):
            setattr(properties, p, True)

        # non-native is assumed to want symint bindings if you wrote symint
        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
        schema.opkind = op.get("opkind")
        nodes.append(gen_lazy_ir.gen(schema)[0])

    return nodes
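# Illustrative shape of one `non_native` entry consumed above; only the keys read
# by the loop ("func", optional "properties", optional "opkind") matter, and the
# concrete values here are assumptions:
#   {
#       "func": "scalar(Scalar value, ScalarType type) -> Tensor",
#       "opkind": "at::prim::Constant",
#       "properties": ["ShapeCompute"],
#   }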