style-transfer-pytorch

Форк
0
78 строк · 2.7 Кб
1
"""Matrix square roots with backward passes.
2

3
Cleaned up from https://github.com/msubhransu/matrix-sqrt.
4
"""
5

6
import torch
7

8

9
def sqrtm_ns(a, num_iters=10):
10
    if a.ndim < 2:
11
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
12
    if a.shape[-2] != a.shape[-1]:
13
        raise RuntimeError('tensor must be batches of square matrices')
14
    if num_iters < 0:
15
        raise RuntimeError('num_iters must not be negative')
16
    norm_a = a.pow(2).sum(dim=[-2, -1], keepdim=True).sqrt()
17
    y = a / norm_a
18
    eye = torch.eye(a.shape[-1], device=a.device, dtype=a.dtype) * 3
19
    z = torch.eye(a.shape[-1], device=a.device, dtype=a.dtype)
20
    z = z.repeat([*a.shape[:-2], 1, 1])
21
    for i in range(num_iters):
22
        t = (eye - z @ y) / 2
23
        y = y @ t
24
        z = t @ z
25
    return y * norm_a.sqrt()
26

27

28
class _MatrixSquareRootNSLyap(torch.autograd.Function):
29
    @staticmethod
30
    def forward(ctx, a, num_iters, num_iters_backward):
31
        z = sqrtm_ns(a, num_iters)
32
        ctx.save_for_backward(z, torch.tensor(num_iters_backward))
33
        return z
34

35
    @staticmethod
36
    def backward(ctx, grad_output):
37
        z, num_iters = ctx.saved_tensors
38
        norm_z = z.pow(2).sum(dim=[-2, -1], keepdim=True).sqrt()
39
        a = z / norm_z
40
        eye = torch.eye(z.shape[-1], device=z.device, dtype=z.dtype) * 3
41
        q = grad_output / norm_z
42
        for i in range(num_iters):
43
            eye_a_a = eye - a @ a
44
            q = q = (q @ eye_a_a - a.transpose(-2, -1) @ (a.transpose(-2, -1) @ q - q @ a)) / 2
45
            if i < num_iters - 1:
46
                a = a @ eye_a_a / 2
47
        return q / 2, None, None
48

49

50
def sqrtm_ns_lyap(a, num_iters=10, num_iters_backward=None):
51
    if num_iters_backward is None:
52
        num_iters_backward = num_iters
53
    if num_iters_backward < 0:
54
        raise RuntimeError('num_iters_backward must not be negative')
55
    return _MatrixSquareRootNSLyap.apply(a, num_iters, num_iters_backward)
56

57

58
class _MatrixSquareRootEig(torch.autograd.Function):
59
    @staticmethod
60
    def forward(ctx, a):
61
        vals, vecs = torch.linalg.eigh(a)
62
        ctx.save_for_backward(vals, vecs)
63
        return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
64

65
    @staticmethod
66
    def backward(ctx, grad_output):
67
        vals, vecs = ctx.saved_tensors
68
        d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
69
        vecs_t = vecs.transpose(-2, -1)
70
        return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
71

72

73
def sqrtm_eig(a):
74
    if a.ndim < 2:
75
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
76
    if a.shape[-2] != a.shape[-1]:
77
        raise RuntimeError('tensor must be batches of square matrices')
78
    return _MatrixSquareRootEig.apply(a)
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.