# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging

from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)

from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)

from torch.fx.operator_schemas import normalize_function

from torch.utils._stats import count_label

pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []


aten = torch._ops.ops.aten


def ordered_set(*items):
    return dict.fromkeys(items, True)
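
# NOTE: a dict is used here as an insertion-ordered set: the keys preserve the
# order in which ops are listed while membership checks stay O(1).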


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    if device.type == "hpu":
        return False
    return True


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
    )


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator
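
# Illustrative sketch (not part of the registry): an impl can be registered for
# a specific OpOverload or for a predicate over OpOverloads, e.g.
#
#     @register_op_impl(aten.some_op.default)          # hypothetical overload
#     def some_op_impl(fake_mode, func, *args, **kwargs):
#         ...
#
#     @register_op_impl(lambda func: "some_substring" in func.name())
#     def predicate_based_impl(fake_mode, func, *args, **kwargs):
#         ...
#
# OpOverload registrations go into op_implementations_dict (looked up directly),
# while predicate registrations are appended to op_implementations_checks and
# scanned in order.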


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)


@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)
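
# The pattern above is the common one in this file: run the real meta kernel
# under in_kernel_invocation_manager (so inputs temporarily report a meta
# device), then wrap the meta result in a FakeTensor that carries the device
# the output would really live on (out_device).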


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)


# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor is data-dependent only in some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
    # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if fake_mode.shape_env is None or not fake_mode.shape_env.allow_scalar_outputs:
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
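
# In other words, under a fake mode whose ShapeEnv allows scalar outputs,
# calling .item() on a fake tensor does not produce a concrete Python number:
# it produces a fresh unbacked SymFloat/SymInt/SymBool whose value is only
# constrained later.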


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if arg.nonzero_memo is None:
        nnz = fake_mode.shape_env.create_unbacked_symint()

        # This is unsound, but it works well in practice
        # See https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
        # TODO: Add a config knob to turn off this unsound behavior
        #
        # NB: If numel < 2, the bounds here might be COMPLETELY
        # disjoint with what can actually occur.  But this is fine:
        # remember, the hypothesis is that if your later code works
        # with N >= 2, it will work with N = 1 and N = 0.
        maxval = sys.maxsize - 1

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()):
            # Don't upgrade the range if numel is less than two, since we then
            # have an empty range which makes things go explodey.  We also
            # don't allow for 2 because that would specialize the unbacked
            # SymInt to 2, which is also likely to be buggy.
            if arg.numel() > 2:
                maxval = int(arg.numel())

        _constrain_range_for_size(nnz, max=maxval)

        arg._nonzero_memo = nnz
        arg._nonzero_memo_vc = arg._version

    return arg.new_empty((arg.nonzero_memo, arg.dim()), dtype=torch.int64)
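
# The memo above makes repeated nonzero() calls on the same (unmodified) fake
# tensor reuse the same unbacked symbol for nnz; _nonzero_memo_vc records the
# version counter so the memo stops being used once the tensor is mutated.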


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )

    if not has_free_symbols(self.numel()):
        if self.numel() > 2:
            maxval = int(self.numel())

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)

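
# Helper shared by several multi-device impls below: run the op under the meta
# kernel, then return a FakeTensor pinned to the *input's* device (or the input
# itself when the op returns its first argument, e.g. copy_).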
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError as not_implemented_error:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake


# Don't default to default device handling,
# since the op can take non-zero-sized cpu
# index tensors with a cuda self
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the unsqueezing of the input is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )


@register_op_impl(aten._scaled_dot_product_flash_attention.default)
def meta__scaled_dot_product_flash(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)
    max_seqlen_batch_k = key.size(2)

    query_t = query.transpose(1, 2)
    # empty_like already returns a fake tensor so we don't need to convert it
    attention = torch.empty_like(query_t).transpose(1, 2)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            device=query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._scaled_dot_product_efficient_attention.default)
def meta__scaled_dot_product_efficient(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: attn_bias, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
    logsum_exp = convert_tensor(
        torch.empty(
            (B, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    res = res.transpose(1, 2)

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset


@register_op_impl(aten._flash_attention_forward.default)
def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    cum_seq_q = kwargs["cum_seq_q"]
    cum_seq_k = kwargs["cum_seq_k"]
    max_q = kwargs["max_q"]
    max_k = kwargs["max_k"]
    return_debug_mask = kwargs["return_debug_mask"]
    # unused: value, dropout_p, is_causal, scale

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    # NB: there are two underlying paths:
    # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
    # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
    #    includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
    batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
    max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
    max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
    num_heads = query.size(-2)
    head_dim = query.size(-1)

    # Cuda Path
    # note: empty_like already returns a fake tensor, we don't need to wrap it
    attention = torch.empty_like(query)
    logsumexp = convert_tensor(
        torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q),
            dtype=torch.float,
            device="meta",
        ),
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = convert_tensor(
            torch.empty(
                (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
                dtype=query.dtype,
                device="meta",
            ),
            query.device,
        )
    else:
        debug_mask = convert_tensor(
            torch.empty(0, dtype=query.dtype, device="meta"),
            query.device,
        )

    # See Note [Seed and Offset]:
    return (
        attention,
        logsumexp,
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        convert_tensor(torch.empty((), dtype=torch.long, device="meta"), query.device),
        debug_mask,
    )


@register_op_impl(aten._efficient_attention_forward.default)
def meta__efficient_attention_forward(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    query = kwargs["query"]
    key = kwargs["key"]
    value = kwargs["value"]
    cu_seqlens_q = kwargs["cu_seqlens_q"]
    max_seqlen_q = kwargs["max_seqlen_q"]
    max_seqlen_k = kwargs["max_seqlen_k"]
    compute_log_sumexp = kwargs["compute_log_sumexp"]
    # unused: bias, cu_seqlens_k, dropout_p, custom_mask_type, scale, causal_diagonal, seqlen_k

    def convert_tensor(t, device):
        return FakeTensor(fake_mode, t, device)

    B = query.size(0)
    M = query.size(1)
    N = key.size(1)
    num_heads = query.size(-2)
    K = query.size(-1)
    Kv = value.size(-1)

    res = convert_tensor(
        torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device="meta"),
        query.device,
    )

    logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
    actual_max_seqlen_q = M
    if cu_seqlens_q is not None:
        assert max_seqlen_q is not None
        actual_max_seqlen_q = max_seqlen_q
    actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
    logsumexp_dim = (
        math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
    )
    logsum_exp = convert_tensor(
        torch.empty(
            (logsumexp_batch_dim, num_heads, logsumexp_dim),
            dtype=torch.float,
            device="meta",
        ),
        query.device,
    )

    # See Note [Seed and Offset]:
    seed = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )
    offset = convert_tensor(
        torch.empty((), dtype=torch.long, device="meta"), query.device
    )

    return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB.  This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known.  However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i})",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)
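
# Worked example of the broadcasting rule above (standard NumPy/PyTorch
# semantics): infer_size((3, 1, 5), (4, 5)) -> (3, 4, 5); a size of 1 (or a
# missing leading dimension) broadcasts against the other operand's size.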


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if isinstance(op, torch.Tensor) and op.shape == final_shape:
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check all tensors are on the same device
        # cpu scalars are assumed to be allowed
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    return FAST_OP_IMPLEMENTATIONS
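
# Rough usage sketch (how a caller such as FakeTensorMode might consult these
# impls; exact dispatch details may vary):
#
#     fast_impls = get_fast_op_impls()
#     if func in fast_impls:
#         return fast_impls[func](fake_mode, *args, **kwargs)
#
# i.e. the fast binary impls are looked up directly by overload and, per the
# comment on register_fast_op_impl, run before decompositions.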