pytorch / fake_impls.py · 1057 lines · 34.6 KB

# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging

from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)

from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function

from torch.utils._stats import count_label

pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []


aten = torch._ops.ops.aten


def ordered_set(*items):
    return dict.fromkeys(items, True)


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    if device.type == "hpu":
        return False
    return True


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
    )


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)
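# Note on how the registration machinery above is used (illustrative only; the
# concrete registrations follow below in this file).  `register_op_impl`
# accepts either a specific OpOverload, a list/tuple of overloads, or a
# predicate over OpOverloads, e.g. roughly:
#
#     @register_op_impl(aten.some_op.default)               # hypothetical exact overload
#     def some_op_impl(fake_mode, func, *args, **kwargs): ...
#
#     @register_op_impl(lambda func: "fft" in func.name())  # predicate
#     def fft_like_impl(fake_mode, func, *args, **kwargs): ...
#
# Exact-overload registrations land in `op_implementations_dict` and are
# dispatched through `dispatch_to_op_implementations_dict` above; predicate
# registrations are appended to `op_implementations_checks` and probed in
# registration order.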


@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)
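# The function above is the core FakeTensor pattern used throughout this file:
# run the underlying kernel with device="meta" so that only shapes, dtypes and
# strides are computed, then wrap the resulting meta tensor in a FakeTensor
# that merely reports the requested device (`out_device`) without allocating
# on it.  A rough sketch of the observable behavior (illustrative):
#
#     with torch._subclasses.FakeTensorMode():
#         t = torch.empty(2, 3, device="cuda")  # routed through `constructors`
#     # t is a FakeTensor backed by a meta tensor; t.device reports cuda and
#     # t.shape is (2, 3), but no CUDA memory was allocated.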


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)


# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor is data-dependent in only some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
        # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
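# In other words: under fake tensors, Tensor.item() (aten._local_scalar_dense)
# cannot produce a concrete Python number, so a fresh *unbacked* symbol is
# returned instead.  Rough sketch (assumes a ShapeEnv with
# allow_scalar_outputs=True):
#
#     u = fake_int_tensor.item()   # -> an unbacked SymInt such as u0, no hint
#
# Downstream code has to either compute with `u` symbolically or constrain it
# explicitly (e.g. via torch._check) before branching on it.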


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if arg.nonzero_memo is None:
        nnz = fake_mode.shape_env.create_unbacked_symint()

        # This is unsound, but it works well in practice
        # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
        # TODO: Add a config knob to turn off this unsound behavior
        #
        # NB: If numel < 2, the bounds here might be COMPLETELY
        # disjoint with what can actually occur.  But this is fine:
        # remember, the hypothesis is that if your later code works
        # with N >= 2, it will work with N = 1 and N = 0.
        maxval = sys.maxsize - 1

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()):
            # Don't upgrade the range if numel is less than two, since we then
            # have an empty range which makes things go explodey.  We also
            # don't allow for 2 because that would specialize the unbacked
            # SymInt to 2, which is also likely to be buggy.
            if arg.numel() > 2:
                maxval = int(arg.numel())

        _constrain_range_for_size(nnz, max=maxval)

        arg._nonzero_memo = nnz
        arg._nonzero_memo_vc = arg._version

    return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
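# Why the memo: calling nonzero() twice on the same (unmodified) fake tensor
# should yield outputs whose first dimension is literally the same unbacked
# SymInt, so the two results can be reasoned about as having equal length.
# `_nonzero_memo` caches that SymInt and `_nonzero_memo_vc` records the
# tensor's version counter so the memo is dropped if the tensor is mutated
# in between.  Rough sketch:
#
#     n1 = fake_x.nonzero().size(0)
#     n2 = fake_x.nonzero().size(0)
#     # n1 and n2 are the same symbol, so n1 == n2 needs no guard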


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )

    if not has_free_symbols(self.numel()):
        if self.numel() > 2:
            maxval = int(self.numel())

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)


def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)
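# Note: the helper above is what lets the multi-device ops registered below
# (copy, copy_, slice_scatter, _unsafe_index_put, index_put, ...) work under
# fake tensors: the computation itself runs on meta, and the result is
# re-wrapped with the device of the `input` argument instead of going through
# the usual common-device inference.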


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError as not_implemented_error:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake


# Don't default to default device handling,
# since the op can take non-zero sized cpu
# index tensors with a cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the unsqueezing of the input is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )


@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)
    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    # empty_like already returns a fake tensor so we don't need to convert it
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            device=query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = convert_tensor(
        torch.empty(
            (B, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset


@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    cum_seq_q = kwargs["cum_seq_q"]
    cum_seq_k = kwargs["cum_seq_k"]
    max_q = kwargs["max_q"]
    max_k = kwargs["max_k"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    # note: empty_like already returns a fake tensor, we don't need to wrap it
    attention = torch.empty_like(query)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = kwargs["max_seqlen_q"]
    max_seqlen_k = kwargs["max_seqlen_k"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = convert_tensor(
        torch.empty(
            (logsumexp_batch_dim, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB.  This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known.  However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i}",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)
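# Worked example of the broadcasting rule implemented above (plain ints, no
# SymInts involved):
#
#     infer_size((5, 1, 3), (4, 3))
#     # trailing dim: 3 vs 3          -> 3
#     # middle dim:   1 vs 4          -> 4  (sizeA == 1 broadcasts)
#     # leading dim:  5 vs <missing>  -> 5  (missing dims count as 1)
#     # result: (5, 4, 3)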


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if isinstance(op, torch.Tensor) and op.shape == final_shape:
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check all tensors on same device
        # cpu scalars are assumed to be allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl
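# Rough behavior of the fast path above (illustrative): for two fake CUDA
# tensors of shapes (8, 1, 4) and (8, 3, 4) with dtypes float16 and float32,
# it infers the broadcast shape (8, 3, 4), promotes to float32 via
# elementwise_dtypes, checks that one operand already has the final shape and
# that every operand is contiguous, and then returns a contiguous meta-backed
# FakeTensor on the CUDA device without ever calling `slow_ref`.  Anything
# that does not fit the fast path (nontrivial two-sided broadcasting, mixed
# devices, unusual layouts) falls back to `slow_ref` under the fake mode.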


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    return FAST_OP_IMPLEMENTATIONS
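# Sketch of how the fast impls are typically consumed (illustrative; the
# actual wiring lives in the FakeTensorMode dispatch code):
#
#     fast_impls = get_fast_op_impls()
#     impl = fast_impls.get(torch.ops.aten.add.Tensor)
#     if impl is not None:
#         out = impl(fake_mode, fake_a, fake_b)  # skips the op_implementations checks
#
# Thanks to functools.lru_cache, the registrations run only once; later calls
# return the same FAST_OP_IMPLEMENTATIONS dict.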