# style-transfer-pytorch
# 78 lines · 2.7 KB
"""Matrix square roots with backward passes.

Cleaned up from https://github.com/msubhransu/matrix-sqrt.
"""
5
6import torch
7
8
def sqrtm_ns(a, num_iters=10):
    """Approximate the matrix square root of `a` by Newton-Schulz iteration.

    The input is scaled by its Frobenius norm, a coupled pair of iterates is
    refined `num_iters` times, and the scaling is undone on the way out. Only
    ordinary tensor operations are used, so autograd can differentiate
    through the whole loop.

    Args:
        a: Tensor whose last two dimensions form square matrices; any
            leading dimensions are treated as batch dimensions.
        num_iters: Number of refinement steps. With 0 the result is simply
            `a` divided by the square root of its Frobenius norm.

    Returns:
        A tensor with the same shape as `a`.

    Raises:
        RuntimeError: If `a` has fewer than 2 dimensions, its trailing
            matrices are not square, or `num_iters` is negative.
    """
    if a.ndim < 2:
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
    *batch_dims, rows, cols = a.shape
    if rows != cols:
        raise RuntimeError('tensor must be batches of square matrices')
    if num_iters < 0:
        raise RuntimeError('num_iters must not be negative')

    # Per-matrix Frobenius norm, kept broadcastable against `a`.
    # NOTE(review): the rescaling is presumably what keeps the iteration
    # inside its convergence region -- confirm against the upstream project.
    frob_norm = a.pow(2).sum(dim=[-2, -1], keepdim=True).sqrt()

    identity = torch.eye(cols, device=a.device, dtype=a.dtype)
    three_identity = identity * 3

    y = a / frob_norm
    z = identity.repeat([*batch_dims, 1, 1])
    for _ in range(num_iters):
        step = (three_identity - z @ y) / 2
        # Both updates use the iterates from the previous round.
        y, z = y @ step, step @ z

    # sqrt(a) = sqrt(norm) * sqrt(a / norm), so restore the scale.
    return y * frob_norm.sqrt()
26
27
class _MatrixSquareRootNSLyap(torch.autograd.Function):
    """Newton-Schulz matrix square root with a hand-written iterative backward.

    The forward pass delegates to `sqrtm_ns`. The backward pass does not
    differentiate through the forward loop; it runs its own fixed number of
    iterations starting from the saved forward output.
    NOTE(review): going by the class name and the upstream matrix-sqrt
    project, the backward loop is presumably an iterative solver for the
    Lyapunov equation that defines the gradient -- confirm against upstream.
    """

    @staticmethod
    def forward(ctx, a, num_iters, num_iters_backward):
        """Compute the square root and stash what the backward pass needs.

        Args:
            ctx: Autograd context.
            a: Batch of square matrices.
            num_iters: Iteration count passed to `sqrtm_ns`.
            num_iters_backward: Iteration count used by `backward`.

        Returns:
            The result of `sqrtm_ns(a, num_iters)`.
        """
        z = sqrtm_ns(a, num_iters)
        ctx.save_for_backward(z)
        # Non-tensor state is kept on ctx as a plain attribute rather than
        # being wrapped in a tensor just to go through save_for_backward.
        ctx.num_iters_backward = num_iters_backward
        return z

    @staticmethod
    def backward(ctx, grad_output):
        """Iteratively compute the gradient with respect to `a`.

        Args:
            ctx: Autograd context populated by `forward`.
            grad_output: Gradient of the loss with respect to the forward
                output.

        Returns:
            A 3-tuple: the gradient for `a`, and `None` for each of the two
            iteration-count arguments.
        """
        (z,) = ctx.saved_tensors
        num_iters = ctx.num_iters_backward

        # Scale by the per-matrix Frobenius norm of the forward output, the
        # same normalization the forward iteration applies to its input.
        norm_z = z.pow(2).sum(dim=[-2, -1], keepdim=True).sqrt()
        a = z / norm_z
        # Note: `eye` holds 3 * I, not I.
        eye = torch.eye(z.shape[-1], device=z.device, dtype=z.dtype) * 3
        q = grad_output / norm_z
        for i in range(num_iters):
            eye_a_a = eye - a @ a
            a_t = a.transpose(-2, -1)
            q = (q @ eye_a_a - a_t @ (a_t @ q - q @ a)) / 2
            # `a` is only read at the start of the next iteration, so its
            # update is skipped on the final pass.
            if i < num_iters - 1:
                a = a @ eye_a_a / 2
        return q / 2, None, None
48
49
def sqrtm_ns_lyap(a, num_iters=10, num_iters_backward=None):
    """Matrix square root via `_MatrixSquareRootNSLyap`.

    Args:
        a: Batch of square matrices.
        num_iters: Number of forward iterations.
        num_iters_backward: Number of backward iterations. When `None`, the
            value of `num_iters` is used.

    Returns:
        The output of `_MatrixSquareRootNSLyap.apply`.

    Raises:
        RuntimeError: If the effective backward iteration count is negative.
    """
    backward_iters = num_iters if num_iters_backward is None else num_iters_backward
    if backward_iters < 0:
        raise RuntimeError('num_iters_backward must not be negative')
    return _MatrixSquareRootNSLyap.apply(a, num_iters, backward_iters)
56
57
58class _MatrixSquareRootEig(torch.autograd.Function):
59@staticmethod
60def forward(ctx, a):
61vals, vecs = torch.linalg.eigh(a)
62ctx.save_for_backward(vals, vecs)
63return vecs @ vals.abs().sqrt().diag_embed() @ vecs.transpose(-2, -1)
64
65@staticmethod
66def backward(ctx, grad_output):
67vals, vecs = ctx.saved_tensors
68d = vals.abs().sqrt().unsqueeze(-1).repeat_interleave(vals.shape[-1], -1)
69vecs_t = vecs.transpose(-2, -1)
70return vecs @ (vecs_t @ grad_output @ vecs / (d + d.transpose(-2, -1))) @ vecs_t
71
72
def sqrtm_eig(a):
    """Matrix square root via `_MatrixSquareRootEig`.

    Args:
        a: Tensor whose last two dimensions form square matrices.

    Returns:
        The output of `_MatrixSquareRootEig.apply(a)`.

    Raises:
        RuntimeError: If `a` has fewer than 2 dimensions or its trailing
            matrices are not square.
    """
    if a.ndim < 2:
        raise RuntimeError('tensor of matrices must have at least 2 dimensions')
    rows, cols = a.shape[-2:]
    if rows != cols:
        raise RuntimeError('tensor must be batches of square matrices')
    return _MatrixSquareRootEig.apply(a)
79