# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Any, Callable, Dict, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_mutation,
    autograd_not_implemented,
    reenter_make_fx,
    UnsupportedAliasMutationException,
)
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    make_fx,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
from torch.overrides import TorchFunctionMode


class TransformGetItemToIndex(TorchFunctionMode):
    # This is needed since we want to support calling
    # A[q_idx], where q_idx is a scalar tensor in score_mod.
    # Today, when q_idx is a scalar tensor, we implicitly convert it to a python
    # scalar and create a view. We do not want that behavior in this case, so we
    # use this TorchFunctionMode to override that behavior for score_mod
    # wherever we're running it.
    def __torch_function__(self, func, types, args, kwargs=None):
        if func == torch.Tensor.__getitem__:
            index_args = pytree.tree_leaves(args[1])
            if all(isinstance(x, torch.Tensor) for x in index_args):
                return torch.ops.aten.index(args[0], index_args)
        return func(*args, **(kwargs or {}))
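

# Illustrative sketch (not part of the original source): under this mode, indexing
# with a scalar tensor is routed through aten.index instead of being converted to a
# Python int (which would create a view), e.g.
#
#     a = torch.randn(8)
#     q_idx = torch.tensor(3)
#     with TransformGetItemToIndex():
#         v = a[q_idx]  # dispatched as torch.ops.aten.index(a, [q_idx])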


class FlexAttentionHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        score_mod: Callable,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            score_mod,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention = FlexAttentionHOP()


class FlexAttentionBackwardHOP(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("flex_attention_backward")

    def __call__(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        out: torch.Tensor,
        logsumexp: torch.Tensor,
        grad_out: torch.Tensor,
        grad_logsumexp: torch.Tensor,
        fw_graph: Union[Callable, GraphModule],
        joint_graph: GraphModule,
        block_mask: Tuple,
        scale: float,
        kernel_options: Dict[str, Any],
        score_mod_other_buffers: Tuple = (),
        mask_mod_other_buffers: Tuple = (),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if not all(
            isinstance(buf, torch.Tensor)
            for buf in score_mod_other_buffers + mask_mod_other_buffers
        ):
            raise RuntimeError("Other buffers must be tensors.")
        return super().__call__(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )


flex_attention_backward = FlexAttentionBackwardHOP()


def _math_attention_inner(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32

    scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # first input is score
    score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)

    mask_mod = block_mask[-1]
    mask_mod_in_dim_buffers = (None,) * len(mask_mod_other_buffers)
    mask_mod = _vmap_for_bhqkv(mask_mod, prefix=(), suffix=mask_mod_in_dim_buffers)

    with TransformGetItemToIndex():
        scores = (scores * scale).to(working_precision)
        post_mod_scores = torch.where(
            mask_mod(b, h, m, n, *mask_mod_other_buffers),
            score_mod(scores, b, h, m, n, *score_mod_other_buffers),
            torch.tensor(-float("inf"), dtype=working_precision, device=scores.device),
        )

    return scores, post_mod_scores


def math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager implementation

    This implementation uses vmap to vectorize the score_mod function over the batch, head, m, and n dimensions.
    We then apply the vectorized score_mod function to the scores matrix. Each application of vmap vectorizes over
    one of the batch, head, m, or n dimensions, so we need to apply vmap 4 times to vectorize over all 4 dimensions.

    Args:
        query: The query tensor
        key: The key tensor
        value: The value tensor
        score_mod: The score_mod function
        other_buffers: Other buffers that are passed to the score_mod function
    """
190# broadcast query & key along head dim for GQA191G = query.size(1) // key.size(1)192value = torch.repeat_interleave(value, G, dim=1)193key = torch.repeat_interleave(key, G, dim=1)194
195_, post_mod_scores = _math_attention_inner(196query,197key,198value,199score_mod,200block_mask,201scale,202kernel_options,203score_mod_other_buffers,204mask_mod_other_buffers,205)206
207# Set fully masked rows' sumexp to 0.0208logsumexp = post_mod_scores.logsumexp(dim=-1)209masked_rows = torch.all(post_mod_scores == -float("inf"), dim=-1)210logsumexp = torch.where(masked_rows, -float("inf"), logsumexp)211
212post_mod_scores = torch._safe_softmax(post_mod_scores, dim=-1)213
214return post_mod_scores.to(query.dtype) @ value, logsumexp / math.log(2)215
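

# Note: math_attention returns logsumexp in base 2 (the natural-log LSE divided by
# math.log(2)); sdpa_dense_backward below undoes this change of base via
# `logsumexp * math.log(2)` before recomputing the softmax.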


@flex_attention.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    out, lse = math_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    out = out.contiguous()
    return out, lse


def trace_flex_attention(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Traces the flex_attention operator with the given score_mod function and other_buffers.

    Trace SDPA will call make_fx with "fake" example vals and then trace the score_mod function.
    This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We
    access this graph module in inductor to inline the score_mod function into the triton template.
    """
    example_out = flex_attention(
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_mod = block_mask[-1]
    with TransformGetItemToIndex():
        score_graph = reenter_make_fx(score_mod)(
            *example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_mod)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_score")
    proxy_mode.tracer.root.register_module(qualname, score_graph)
    mask_qualname = proxy_mode.tracer.get_fresh_qualname("sdpa_mask")
    proxy_mode.tracer.root.register_module(mask_qualname, mask_graph)
    node_args = (
        query,
        key,
        value,
        score_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", flex_attention, proxy_args, {}
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention.py_impl(ProxyTorchDispatchMode)
def flex_attention_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention(
        mode,
        query,
        key,
        value,
        score_mod,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention.py_functionalize_impl
def flex_attention_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention operator.

    Right now we are unwrapping each tensor and then redispatching to the next.
    However, we want to guard against any mutations of the other_buffers inside
    the score_mod function, since those are free variables.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    example_vals = (
        [torch.zeros((), dtype=query.dtype)]
        + [torch.zeros((), dtype=torch.int) for _ in range(4)]
        + list(score_mod_other_buffers_unwrapped)
    )
    with ctx.redispatch_to_next() as m:
        functional_score_mod = ctx.functionalize(score_mod)
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch
        with TransformGetItemToIndex():
            mutates = _has_potential_branch_input_mutation(
                functional_score_mod, example_vals, pre_dispatch
            )
        # We only care about mutations of existing buffers since we can't replay these.
        # However, we can just error if anything is detected
        if mutates:
            raise UnsupportedAliasMutationException("Mutations detected in score_mod")

        out = flex_attention(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            functional_score_mod,
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )
    return ctx.wrap_tensors(out)  # type: ignore[return-value, arg-type]


@flex_attention.py_impl(FakeTensorMode)
def flex_attention_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with mode:
        v_head_dim = value.size(-1)
        batch_size, num_heads, seq_len_q, q_head_dim = query.shape
        logsumexp = query.new_empty(
            batch_size, num_heads, seq_len_q, dtype=torch.float32
        )
        out_shape = (batch_size, num_heads, seq_len_q, v_head_dim)
        return query.new_empty(out_shape), logsumexp


# ---------------------------- Autograd Implementation ----------------------------
def create_fw_bw_graph(score_mod, index_values, other_buffers):
    # See Note:[HOP create fw_bw graph]

    # All of these imports need to be here in order to avoid circular dependencies
    from torch._dispatch.python import suspend_functionalization
    from torch._functorch.aot_autograd import AOTConfig, create_joint
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
    from torch._subclasses.functional_tensor import disable_functional_mode
    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

    dummy_aot_config = AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        num_params_buffers=0,
        aot_id=0,
        keep_inference_input_mutations=False,
    )

    with suspend_functionalization(), disable_functional_mode():
        with disable_proxy_modes_tracing():

            def _from_fun(t):
                return torch.empty_strided(
                    t.size(),
                    t.stride(),
                    device=t.device,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                )

            # If someone runs this hop under the default compiler backend ("eager"),
            # then this path will be run with the actual user inputs. We convert them
            # to fake tensors in order to not perform any actual compute.
            from torch._guards import detect_fake_mode

            fake_mode = detect_fake_mode(index_values)
            if fake_mode is None:
                fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

            with fake_mode:
                unwrapped_score_mod_indexes = pytree.tree_map(_from_fun, index_values)
                unwrapped_other_buffers = pytree.tree_map(_from_fun, other_buffers)

            assert all(isinstance(t, FakeTensor) for t in unwrapped_score_mod_indexes)
            assert all(isinstance(t, FakeTensor) for t in unwrapped_other_buffers)

            example_flat_out = pytree.tree_map(
                _from_fun,
                score_mod(*unwrapped_score_mod_indexes, *unwrapped_other_buffers),
            )
            if not isinstance(example_flat_out, torch.Tensor):
                raise RuntimeError(
                    "Expected output of score_mod to be a tensor. "
                    f"Got type {type(example_flat_out)}."
                )
            example_grad = _from_fun(example_flat_out)

        def joint_f(score, b, h, m, n, example_grad, *other_buffers):
            def fw_with_masks(*args):
                fw_out = score_mod(*args)
                out_requires_grad = fw_out.requires_grad
                return ((fw_out,), (out_requires_grad,))

            joint = create_joint(fw_with_masks, aot_config=dummy_aot_config)
            args = [score, b, h, m, n] + list(other_buffers)
            optional_grad = [example_grad] if example_grad.requires_grad else []
            _, grads = joint(args, optional_grad)

            return grads

        joint_graph = make_fx(joint_f)(
            *unwrapped_score_mod_indexes, example_grad, *unwrapped_other_buffers
        )
        return score_mod, joint_graph
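
# The joint graph traced above takes (score, b, h, q_idx, kv_idx, grad_out,
# *other_buffers) and returns the gradients of score_mod's inputs;
# sdpa_dense_backward applies it elementwise over the full score matrix via
# _vmap_for_bhqkv.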


class FlexAttentionAutogradOp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        any_buffer_requires_grad = any(
            buffer.requires_grad
            for buffer in score_mod_other_buffers + mask_mod_other_buffers
        )
        assert (
            not any_buffer_requires_grad
        ), "Captured buffers that require grad are not yet supported."
        ctx._fw_graph = fw_graph
        ctx._joint_graph = joint_graph
        ctx._mask_graph = block_mask[-1]
        # KV_BLOCK_SIZE and Q_BLOCK_SIZE are integers, so can't use ctx.save_for_backward
        ctx._KV_BLOCK_SIZE = block_mask[8]
        ctx._Q_BLOCK_SIZE = block_mask[9]
        ctx.scale = scale
        ctx.kernel_options = kernel_options
        ctx._score_mod_other_buffers_len = len(score_mod_other_buffers)
        with torch._C._AutoDispatchBelowAutograd():
            out, logsumexp = flex_attention(
                query,
                key,
                value,
                fw_graph,
                block_mask,
                scale,
                kernel_options,
                score_mod_other_buffers,
                mask_mod_other_buffers,
            )

        ctx.save_for_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            *block_mask[:8],
            *score_mod_other_buffers,
            *mask_mod_other_buffers,
        )
        return out, logsumexp

    @staticmethod
    def backward(ctx, grad_out, grad_logsumexp):
        fw_args = ctx.saved_tensors
        (
            query,
            key,
            value,
            out,
            logsumexp,
            kv_num_blocks,
            kv_indices,
            full_kv_num_blocks,
            full_kv_indices,
            q_num_blocks,
            q_indices,
            full_q_num_blocks,
            full_q_indices,
            *other_buffers,
        ) = fw_args
        fw_graph = ctx._fw_graph
        joint_graph = ctx._joint_graph
        mask_graph = ctx._mask_graph
        KV_BLOCK_SIZE = ctx._KV_BLOCK_SIZE
        Q_BLOCK_SIZE = ctx._Q_BLOCK_SIZE
        scale = ctx.scale
        kernel_options = ctx.kernel_options
        score_mod_other_buffers = tuple(
            other_buffers[: ctx._score_mod_other_buffers_len]
        )
        mask_mod_other_buffers = tuple(
            other_buffers[ctx._score_mod_other_buffers_len :]
        )
        # We have asserted that other_buffers do not require grad in the forward
        none_grads = [None] * 7
        grad_query, grad_key, grad_value = flex_attention_backward(
            query,
            key,
            value,
            out,
            logsumexp,
            grad_out,
            grad_logsumexp,
            fw_graph,
            joint_graph,
            (
                kv_num_blocks,
                kv_indices,
                full_kv_num_blocks,
                full_kv_indices,
                q_num_blocks,
                q_indices,
                full_q_num_blocks,
                full_q_indices,
                KV_BLOCK_SIZE,
                Q_BLOCK_SIZE,
                mask_graph,
            ),
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
        return grad_query, grad_key, grad_value, *none_grads


@flex_attention.py_impl(DispatchKey.Autograd)
def flex_attention_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    with TransformGetItemToIndex():
        input_requires_grad = any(t.requires_grad for t in (query, key, value))
        if torch.is_grad_enabled() and input_requires_grad:
            example_vals = [
                torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
            ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
            fw_graph, bw_graph = create_fw_bw_graph(
                score_mod, example_vals, score_mod_other_buffers
            )
        else:
            fw_graph, bw_graph = score_mod, None
        out, logsumexp = FlexAttentionAutogradOp.apply(
            query,
            key,
            value,
            fw_graph,
            bw_graph,
            block_mask,
            scale,
            kernel_options,
            score_mod_other_buffers,
            mask_mod_other_buffers,
        )
    return out, logsumexp


# ---------------------------- Backward HOP Implementation ----------------------------


@flex_attention_backward.py_impl(DispatchKey.CompositeExplicitAutograd)
def sdpa_dense_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Callable,  # GraphModule type hint?
    joint_graph: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple,
    mask_mod_other_buffers: Tuple,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    G = query.size(1) // key.size(1)
    key = torch.repeat_interleave(key, G, dim=1)
    value = torch.repeat_interleave(value, G, dim=1)

    # We're undoing the log -> log2 change of base in the forwards
    logsumexp = logsumexp * math.log(2)
    # The backwards formula for the log -> log2 change of base in the forwards
    grad_logsumexp = grad_logsumexp / math.log(2)
    scores, post_mod_scores = _math_attention_inner(
        query,
        key,
        value,
        fw_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    masked_out_rows = logsumexp == -float("inf")
    softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1))
    softmax_scores = torch.where(masked_out_rows.unsqueeze(-1), 0, softmax_scores)

    grad_value = softmax_scores.to(query.dtype).transpose(-2, -1) @ grad_out

    grad_softmax_scores = grad_out @ value.transpose(-2, -1)

    sum_scores = torch.sum(out * grad_out, -1, keepdim=True)
    grad_score_mod = softmax_scores * (
        grad_softmax_scores - sum_scores + grad_logsumexp.unsqueeze(-1)
    )
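
    # Sketch of the formula above (hedged derivation note): with p = softmax(s),
    # dL/ds_ij = p_ij * (dL/dp_ij - sum_k p_ik * dL/dp_ik), and
    # sum_k p_ik * dL/dp_ik == sum_k (out * grad_out)_ik == sum_scores. The extra
    # grad_logsumexp term accounts for logsumexp also being an output of the
    # forward, since d(lse_i)/d(s_ij) = p_ij.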

    b = torch.arange(0, scores.size(0), device=scores.device)
    h = torch.arange(0, scores.size(1), device=scores.device)
    m = torch.arange(0, scores.size(2), device=scores.device)
    n = torch.arange(0, scores.size(3), device=scores.device)

    mask_graph = block_mask[-1]
    # Gradient of the inline score_mod function, with respect to the scores
    captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
    out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
    from torch.nn.attention.flex_attention import _vmap_for_bhqkv

    # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
    # score and gradOut are "fully" batched
    joint_score_mod = _vmap_for_bhqkv(
        joint_graph,
        prefix=(0,),
        suffix=(0,) + captured_buffers_in_dim,
        out_dims=out_dims,
    )
    with TransformGetItemToIndex():
        grad_scores, *_ = joint_score_mod(
            scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers
        )
        grad_scores = grad_scores * scale
        grad_scores = grad_scores.to(query.dtype)

    mask_mod = _vmap_for_bhqkv(
        mask_graph, prefix=(), suffix=(None,) * len(mask_mod_other_buffers)
    )
    with TransformGetItemToIndex():
        mask_scores = mask_mod(b, h, m, n, *mask_mod_other_buffers)
        grad_scores = torch.where(
            mask_scores, grad_scores, torch.tensor(0, dtype=query.dtype)
        )

    grad_query = grad_scores @ key
    grad_key = grad_scores.transpose(-2, -1) @ query

    # Reduce DK, DV along broadcasted heads.
    grad_key = grad_key.view(
        grad_key.size(0), -1, G, grad_key.size(-2), grad_key.size(-1)
    )
    grad_value = grad_value.view(
        grad_value.size(0), -1, G, grad_value.size(-2), grad_value.size(-1)
    )

    grad_key = torch.sum(grad_key, 2, keepdim=False)
    grad_value = torch.sum(grad_value, 2, keepdim=False)

    return grad_query.contiguous(), grad_key.contiguous(), grad_value.contiguous()


def trace_flex_attention_backward(
    proxy_mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """We already have the forward graph and joint graph from the forward pass, so we create a proxy and attach both graphs."""
    example_out = flex_attention_backward(
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )

    fw_example_vals = [
        torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
    ] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
    bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
    mask_example_vals = [torch.zeros((), dtype=torch.int) for _ in range(4)]
    mask_graph = block_mask[-1]
    with TransformGetItemToIndex():
        fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *score_mod_other_buffers)
        joint_graph = reenter_make_fx(joint_graph)(
            *bw_example_vals, *score_mod_other_buffers
        )
        mask_graph = reenter_make_fx(mask_graph)(
            *mask_example_vals, *mask_mod_other_buffers
        )
    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
    block_mask = block_mask[:-1] + (mask_graph,)
    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
    node_args = (
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function",
        flex_attention_backward,
        proxy_args,
        {},
        name="flex_attention_backward",
    )
    return track_tensor_tree(
        example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
    )


@flex_attention_backward.py_impl(ProxyTorchDispatchMode)
def flex_attention_backward_proxy_torch_dispatch_mode(
    mode: ProxyTorchDispatchMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert mode is not None, "Mode should always be enabled for python fallback key"
    return trace_flex_attention_backward(
        mode,
        query,
        key,
        value,
        out,
        logsumexp,
        grad_out,
        grad_logsumexp,
        fw_graph,
        joint_graph,
        block_mask,
        scale,
        kernel_options,
        score_mod_other_buffers,
        mask_mod_other_buffers,
    )


@flex_attention_backward.py_functionalize_impl
def flex_attention_backward_functionalize(
    ctx: torch._subclasses.functional_tensor.BaseFunctionalizeAPI,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Defines the functionalization rules for the flex_attention_backward operator.

    Right now we are unwrapping each tensor and then redispatching to the next.
    Since we know that the forward score_mod function is guaranteed to be free of
    mutations to the other_buffers, we skip that mutation check and go straight to
    redispatching.
    """
    query_unwrapped = ctx.unwrap_tensors(query)
    key_unwrapped = ctx.unwrap_tensors(key)
    value_unwrapped = ctx.unwrap_tensors(value)
    out_unwrapped = ctx.unwrap_tensors(out)
    logsumexp_unwrapped = ctx.unwrap_tensors(logsumexp)
    grad_out_unwrapped = ctx.unwrap_tensors(grad_out)
    grad_logsumexp_unwrapped = ctx.unwrap_tensors(grad_logsumexp)
    block_mask_unwrapped = ctx.unwrap_tensors(block_mask)
    score_mod_other_buffers_unwrapped = ctx.unwrap_tensors(score_mod_other_buffers)
    mask_mod_other_buffers_unwrapped = ctx.unwrap_tensors(mask_mod_other_buffers)

    # Appease the mypy overlords
    assert isinstance(query_unwrapped, torch.Tensor)
    assert isinstance(key_unwrapped, torch.Tensor)
    assert isinstance(value_unwrapped, torch.Tensor)
    assert isinstance(out_unwrapped, torch.Tensor)
    assert isinstance(logsumexp_unwrapped, torch.Tensor)
    assert isinstance(grad_out_unwrapped, torch.Tensor)
    assert isinstance(grad_logsumexp_unwrapped, torch.Tensor)
    assert isinstance(block_mask_unwrapped, tuple)
    assert isinstance(score_mod_other_buffers_unwrapped, tuple)
    assert isinstance(mask_mod_other_buffers_unwrapped, tuple)
    assert all(
        isinstance(item, torch.Tensor)
        for item in score_mod_other_buffers_unwrapped + mask_mod_other_buffers_unwrapped
    )

    with ctx.redispatch_to_next() as m:
        functional_fw_graph = ctx.functionalize(fw_graph)
        functional_joint_graph = ctx.functionalize(joint_graph)

        grad_query, grad_key, grad_value = flex_attention_backward(
            query_unwrapped,
            key_unwrapped,
            value_unwrapped,
            out_unwrapped,
            logsumexp_unwrapped,
            grad_out_unwrapped,
            grad_logsumexp_unwrapped,
            functional_fw_graph,  # type: ignore[arg-type]
            functional_joint_graph,  # type: ignore[arg-type]
            block_mask_unwrapped,
            scale,
            kernel_options,
            score_mod_other_buffers_unwrapped,
            mask_mod_other_buffers_unwrapped,
        )

    return ctx.wrap_tensors((grad_query, grad_key, grad_value))  # type: ignore[return-value,arg-type]


@flex_attention_backward.py_impl(FakeTensorMode)
def flex_attention_backward_fake_tensor_mode(
    mode: FakeTensorMode,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    with mode:
        grad_query = torch.empty_like(query)
        grad_key = torch.empty_like(key)
        grad_value = torch.empty_like(value)
        return grad_query, grad_key, grad_value


flex_attention_backward.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(flex_attention_backward, deferred_error=True)
)