stable-diffusion-webui / sd_hijack_optimizations.py
from __future__ import annotations
import math
import psutil
import platform

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

import sgm.modules.attention
import sgm.modules.diffusionmodules.model

# keep references to the original AttnBlock.forward implementations so undo() can restore them
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    """Base class for a cross-attention optimization.

    Subclasses patch CrossAttention.forward and AttnBlock.forward in both the
    ldm and sgm codebases; when no optimization is chosen explicitly, the
    available one with the highest priority is used.
    """

    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"

    @property
    def priority(self):
        return 1000 if shared.device.type == 'mps' else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])
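

# Illustrative sketch, not part of the original module: one way a caller could
# pick an optimization from the list built by list_optimizers(), preferring an
# explicitly requested name and otherwise the highest-priority available one.
# The function name and the `requested` parameter are hypothetical; the real
# selection logic lives elsewhere in the web UI and may differ.
def _example_select_optimization(requested: str | None = None) -> SdOptimization | None:
    candidates: list[SdOptimization] = []
    list_optimizers(candidates)
    available = [opt for opt in candidates if opt.is_available()]
    if requested is not None:
        for opt in available:
            if opt.name.lower() == requested.lower():
                return opt
    return max(available, key=lambda opt: opt.priority, default=None)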


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)


def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        # memory torch has reserved but is not actively using also counts as free
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available
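

# Worked example for get_available_vram(), with hypothetical numbers: if CUDA
# reports mem_free_cuda = 4.0 GiB while torch has mem_reserved = 3.0 GiB and
# mem_active = 2.2 GiB, then mem_free_torch = 0.8 GiB of cached-but-unused
# memory, and mem_free_total = 4.0 + 0.8 = 4.8 GiB. The attention-splitting
# code below budgets its intermediate tensors against this combined figure.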


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        # attention is computed two batch-head rows at a time to bound peak memory
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    dtype = q_in.dtype
    if shared.opts.upcast_attn:
        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        # choose a power-of-two number of slices so the attention scores fit in the available memory
        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)
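

# Worked example for the memory-splitting math above (hypothetical shapes):
# with batch 2 and 8 heads, self-attention on a 512x512 latent gives q and k
# of shape (16, 4096, 40). In fp16 (element_size == 2):
#   tensor_size  = 16 * 4096 * 4096 * 2 bytes = 0.5 GiB
#   mem_required = tensor_size * 3           = 1.5 GiB
# If only 1.0 GiB is free, steps = 2 ** ceil(log2(1.5)) = 2, so the 4096 query
# tokens are processed in two slices of slice_size = 4096 // 2 = 2048 each.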


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
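

# Worked example for einsum_op_tensor_mem (hypothetical shapes): q of shape
# (16, 4096, 40) in fp16 with k_tokens = 4096 gives
#   size_mb = 16 * 4096 * 4096 * 2 // 2**20 = 512.
# With max_tensor_mb = 32 (the CPU default in einsum_op below),
#   int((512 - 1) / 32) = 15, whose bit_length() is 4, so div = 1 << 4 = 16.
# Since div <= q.shape[0] == 16, the work is sliced along dim 0 into chunks of
# 16 // 16 = 1 batch-head each.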


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))


def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)


def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))

# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x


def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        if q.device.type == 'mps':
            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
        else:
            chunk_threshold_bytes = int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
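

# Illustrative sketch, not part of the original module: sub_quad_attention()
# expects q, k and v already flattened to (batch * heads, tokens, head_dim),
# the layout produced by sub_quad_attention_forward above. The shapes and
# chunk settings here are hypothetical; the function is defined only as an
# example and is never called by the web UI.
def _example_sub_quad_usage():
    q = torch.randn(16, 4096, 40)   # 2 images x 8 heads, 64x64 latent tokens
    k = torch.randn(16, 77, 40)     # cross-attention over 77 prompt tokens
    v = torch.randn(16, 77, 40)
    return sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, chunk_threshold=None, use_checkpoint=False)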


def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None


def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)


# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states
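

# Illustrative sketch, not part of the original module: the reshaping above is
# needed because torch.nn.functional.scaled_dot_product_attention works on
# (batch, heads, seq_len, head_dim) tensors, while the CrossAttention
# projections produce (batch, seq_len, heads * head_dim). Shapes here are
# hypothetical and the helper is never called by the web UI.
def _example_sdp_layout():
    batch, seq_len, heads, head_dim = 2, 4096, 8, 40
    q = torch.randn(batch, seq_len, heads * head_dim)
    q = q.view(batch, -1, heads, head_dim).transpose(1, 2)       # -> (2, 8, 4096, 40)
    k = torch.randn(batch, 77, heads, head_dim).transpose(1, 2)  # -> (2, 8, 77, 40)
    v = torch.randn(batch, 77, heads, head_dim).transpose(1, 2)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(batch, -1, heads * head_dim)  # back to (2, 4096, 320)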


def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)


def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h*w)
    del q1

    q = q2.permute(0, 2, 1)   # b,hw,c
    del q2

    k = k1.reshape(b, c, h*w) # b,c,hw
    del k1

    h_ = torch.zeros_like(k, device=q.device)

    mem_free_total = get_available_vram()

    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

    # fall back to a single full-size pass if the token count does not divide evenly into steps
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * (int(c)**(-0.5))
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        v1 = v.reshape(b, c, h*w)
        w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del v1, w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    h3 += x

    return h3


def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)


def sdp_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out


def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)


def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out