from __future__ import annotations

import math
import psutil
import platform

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

import sgm.modules.attention
import sgm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    name: str | None = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"

    @property
    def priority(self):
        return 1000 if shared.device.type == 'mps' else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])
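

# A minimal sketch (not part of the original file) of how a caller could
# consume this registry: gather candidates via list_optimizers, drop the
# ones whose is_available() returns False, and apply the highest-priority
# survivor. The webui's real selection logic lives elsewhere (in
# modules/sd_hijack.py) and also honors the per-optimizer cmd_opt flags;
# this helper only illustrates the intended contract.
def _example_select_optimizer():
    candidates = []
    list_optimizers(candidates)
    available = [o for o in candidates if o.is_available()]
    # ties are not expected; priorities are spaced out deliberately
    return max(available, key=lambda o: o.priority, default=None)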


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)


def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available
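

# Worked example for the CUDA branch above (illustrative numbers): with
# 8 GiB free on the device and 1.5 GiB reserved by the torch allocator, of
# which 1.0 GiB is active, mem_free_torch = 0.5 GiB and the function reports
# 8.5 GiB. Memory torch has reserved but is not actively using counts as
# available, since the allocator can reuse it for new tensors.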


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
    del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
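

# Note: the v1 path above always walks the (batch*heads) dimension two rows
# at a time, which keeps its peak memory small and predictable but makes no
# attempt to size slices from available memory the way the Doggettx variant
# below does.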


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
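

# Worked example of the slicing heuristic above (illustrative numbers for a
# 512x512 image in SD1's first attention block: 8 heads, 4096 tokens, fp16):
# q.shape == (8, 4096, 40), so tensor_size = 8 * 4096 * 4096 * 2 bytes
# (0.25 GiB) and mem_required = 0.75 GiB with the fp16 modifier of 3. With
# 0.5 GiB free, mem_required / mem_free_total = 1.5, so
# steps = 2 ** ceil(log2(1.5)) = 2 and the softmax runs in two slices of
# 2048 query tokens each.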


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)
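

# Illustrative trace of the dispatch above on CPU: for fp16 q of shape
# (8, 4096, 40) against 4096 key tokens, einsum_op_tensor_mem computes
# size_mb = 8 * 4096 * 4096 * 2 / 2**20 = 256. With max_tensor_mb = 32,
# div = 1 << int((256 - 1) / 32).bit_length() = 8, which divides the
# batch-heads dimension exactly, so the work runs as 8 slices of size 1
# through einsum_op_slice_0.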


def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x


def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        if q.device.type == 'mps':
            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
        else:
            chunk_threshold_bytes = int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
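

# Illustrative numbers for the chunking decision above: fp16 q/k of shape
# (8, 4096, 40) give qk_matmul_size_bytes = 8 * 2 * 4096 * 4096 = 256 MiB.
# If 70% of available VRAM exceeds that, kv_chunk_size is overridden to the
# full k_tokens and efficient_dot_product_attention takes its unchunked fast
# path; otherwise the chunk sizes passed in from the sub_quad_* command line
# options apply.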


def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None
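

# Note: MemoryEfficientAttentionFlashAttentionOp is a (forward, backward)
# pair of xformers operators; the probe above only asks the forward op
# whether it supports these inputs. Returning None lets
# memory_efficient_attention pick a suitable op on its own.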


def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)


# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states
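

# Shape walk-through for the SDP path above (illustrative: batch 2, 4096
# tokens, inner_dim 320, 8 heads): q/k/v are reshaped to (2, 8, n, 40),
# scaled_dot_product_attention returns (2, 8, 4096, 40), and the final
# transpose/reshape restores (2, 4096, 320) before the output projection.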


def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)
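

# Note: the sdp_kernel context above disables torch's memory-efficient
# backend, leaving the flash and math kernels. When flash attention cannot
# run, the math fallback materializes the full attention matrix, which can
# be faster on some setups at the cost of the extra memory the plain "sdp"
# variant avoids.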


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)  # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w)  # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3
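

# Scale note for the block above: AttnBlock attends over all h*w spatial
# positions, so a 64x64 feature map produces a b x 4096 x 4096 attention
# matrix; the steps/slice_size logic exists because that matrix alone can
# exceed free memory at larger resolutions.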


def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)


def sdp_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out


def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)


def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
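

# These *_attnblock_forward variants patch the spatial self-attention blocks
# in ldm/sgm's diffusionmodules model (notably the first-stage VAE),
# mirroring the CrossAttention replacements above with the same backends.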