"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""

from typing import Dict, Optional, Tuple

import torch
from torch import _linalg_utils as _utils, Tensor
from torch.overrides import handle_torch_function, has_torch_function


__all__ = ["lobpcg"]

def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
    F = D.unsqueeze(-2) - D.unsqueeze(-1)
    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
    F.pow_(-1)

    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
    Ut = U.mT.contiguous()
    res = torch.matmul(
        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
    )

    return res

def _polynomial_coefficients_given_roots(roots):
    """
    Given the `roots` of a polynomial, find the polynomial's coefficients.

    If roots = (r_1, ..., r_n), then the method returns
    coefficients (a_0, a_1, ..., a_n (== 1)) so that
    p(x) = (x - r_1) * ... * (x - r_n)
         = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0

    Note: for better performance this requires writing a low-level kernel.
    """
    poly_order = roots.shape[-1]
    poly_coeffs_shape = list(roots.shape)
    # the leading coefficient is fixed to 1 and an extra slot
    # simplifies the index arithmetic in the loop below
    poly_coeffs_shape[-1] += 2
    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
    poly_coeffs[..., 0] = 1
    poly_coeffs[..., -1] = 1

    # perform Horner's rule
    for i in range(1, poly_order + 1):
        # note that it is computationally hard to compute backward for this loop,
        # because the gradient is not just a mul; hence the clone when
        # `roots` requires gradients
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
            -1, poly_order - i + 1, i + 1
        )
        poly_coeffs = poly_coeffs_new

    return poly_coeffs.narrow(-1, 1, poly_order + 1)

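# A small illustrative check of the coefficient convention above (assumed
# values, not part of the public API): roots (1, 2) give
# p(x) = (x - 1)(x - 2) = x^2 - 3x + 2, i.e. coefficients (a_0, a_1, a_2) = (2, -3, 1):
#
#     >>> _polynomial_coefficients_given_roots(torch.tensor([1.0, 2.0]))
#     tensor([ 2., -3.,  1.])
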
def _polynomial_value(poly, x, zero_power, transition):
    """
    A generic method for computing poly(x) using Horner's rule.

    Args:
      poly (Tensor): the (possibly batched) 1D Tensor representing
                     polynomial coefficients such that
                     poly[..., i] = (a_{i_0}, ..., a_{i_n} (== 1)), and
                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n

      x (Tensor): the (possibly batched) value to evaluate the polynomial `poly` at.

      zero_power (Tensor): the representation of `x^0`. It is application-specific.

      transition (Callable): the function that accepts some intermediate result
                     `int_val`, the `x` and a specific polynomial coefficient
                     `poly[..., k]` for some iteration `k`.
                     It performs one iteration of Horner's rule,
                     defined as `x * int_val + poly[..., k] * zero_power`.
                     Note that `zero_power` is not a parameter,
                     because the step `+ poly[..., k] * zero_power` depends on `x`,
                     whether it is a vector, a matrix, or something else, so this
                     functionality is delegated to the user.
    """
    res = zero_power.clone()
    for k in range(poly.size(-1) - 2, -1, -1):
        res = transition(res, x, poly[..., k])
    return res

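# For instance, a scalar version of Horner's rule can be sketched with a plain
# multiply-add transition (illustrative only; the helpers below provide the
# vector and matrix variants actually used in this module):
#
#     >>> poly = torch.tensor([2.0, -3.0, 1.0])  # p(x) = x^2 - 3x + 2
#     >>> one = torch.tensor(1.0)                # the representation of x^0
#     >>> _polynomial_value(poly, torch.tensor(3.0), one,
#     ...                   lambda res, x, a: x * res + a * one)
#     tensor(2.)
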
def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) matrix input `x`.
    Check out the `_polynomial_value` function for more details.
    """

    # matrix-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = x.matmul(curr_poly_val)
        res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
        return res

    if zero_power is None:
        zero_power = torch.eye(
            x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
        ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))

    return _polynomial_value(poly, x, zero_power, transition)

def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluates `poly(x)` for the (batched) vector input `x`.
    Check out the `_polynomial_value` function for more details.
    """

    # vector-aware Horner's rule iteration
    def transition(curr_poly_val, x, poly_coeff):
        res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
        return res

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, transition)

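# Illustration of the two evaluators above on p(x) = x^2 - 3x + 2 with
# coefficients (2, -3, 1) (leading coefficient 1, as required):
#
#     >>> poly = torch.tensor([2.0, -3.0, 1.0])
#     >>> _vector_polynomial_value(poly, torch.tensor([1.0, 2.0, 3.0]))
#     tensor([0., 0., 2.])
#     >>> A = torch.diag(torch.tensor([1.0, 2.0, 3.0]))
#     >>> _matrix_polynomial_value(poly, A)  # equals A^2 - 3A + 2I
#     tensor([[0., 0., 0.],
#             [0., 0., 0.],
#             [0., 0., 2.]])
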
def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    # compute a projection operator onto an orthogonal subspace spanned by the
    # columns of U defined as (I - UU^T)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # fix the generator for determinism
    gen = torch.Generator(A.device)

    # orthogonal complement to span(U), obtained by projecting a random
    # matrix onto the subspace orthogonal to span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype, device=A.device, generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # coefficients of the characteristic polynomial of the diagonal tensor D;
    # the diagonal elements of D are exactly its roots
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # accumulate the series sum_k chr_poly_D[k] * sum_i A^i U_grad D^{k-1-i}
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) with Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to span(U_ortho)
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )

    # `chr_poly_D_at_A_to_U_ortho` is positive- or negative-definite depending
    # on `largest` and the parity of the polynomial order (the final value of
    # `k` above), so flip the sign as needed before the Cholesky decomposition
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # gradient part spanned by U
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution that resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res

def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    # if `U` is square, then its columns form a complete eigenspace
    if U.size(-1) == U.size(-2):
        return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
    else:
        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)

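# End-to-end sketch of this backward path (illustrative; assumes a spectrum
# with distinct eigenvalues so the gradient is well defined). `lobpcg`
# symmetrizes A before solving, so A.grad comes out symmetric:
#
#     >>> A0 = torch.randn(9, 9, dtype=torch.float64)
#     >>> A = ((A0 + A0.mT) / 2 + 9 * torch.eye(9, dtype=torch.float64)).requires_grad_()
#     >>> E, U = torch.lobpcg(A, k=2, largest=True)
#     >>> E.sum().backward()
#     >>> torch.allclose(A.grad, A.grad.mT)
#     True
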
class LOBPCGAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # make sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A, k, B, X, n, iK, niter, tol, largest, method, tracker,
            ortho_iparams, ortho_fparams, ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        A_grad = B_grad = None
        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Check for unsupported input:
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet. "
                "Note that lobpcg.forward does though."
            )
        if (
            A.dtype in (torch.complex64, torch.complex128)
            or B is not None
            and B.dtype in (torch.complex64, torch.complex128)
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet. "
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        if largest is None:
            largest = True

        # symeig backward
        if B is None:
            A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)

        # A has index 0 in the forward inputs, B has index 2;
        # all the other inputs get a None gradient
        return (A_grad, None, B_grad) + (None,) * 11

def lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via the `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method, may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. note:: In general, the basic method spends the least time per
      iteration. However, the robust method converges much faster and
      is more stable, so the basic method is generally not
      recommended; there exist cases, though, where its usage may be
      preferred.

    .. warning:: The backward method does not support sparse and complex inputs.
      It works only when `B` is not provided (i.e. `B == None`).
      We are actively working on extensions, and the details of
      the algorithms are going to be published promptly.

    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
      in first-order optimization routines, prior to running `lobpcg`
      we do the following symmetrization map: `A -> (A + A.t()) / 2`.
      The map is performed only when `A` requires gradients.

    Args:

      A (Tensor): the input tensor of size :math:`(*, m, m)`

      B (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When not specified, `B` is interpreted as the
                  identity matrix.

      X (tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n <= m`. When specified, it is used as
                  an initial approximation of eigenvectors. X must be a
                  dense tensor.

      iK (tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as a preconditioner.

      k (integer, optional): the number of requested
                  eigenpairs. Default is the number of :math:`X`
                  columns (when specified) or `1`.

      n (integer, optional): if :math:`X` is not specified then `n`
                  specifies the size of the generated random
                  approximation of eigenvectors. Default value for `n`
                  is `k`. If :math:`X` is specified, the value of `n`
                  (when specified) must be the number of :math:`X`
                  columns.

      tol (float, optional): residual tolerance for the stopping
                 criterion. Default is `feps ** 0.5` where `feps` is the
                 smallest non-zero floating-point number of the given
                 input tensor `A` data type.

      largest (bool, optional): when True, solve the eigenproblem for
                 the largest eigenvalues. Otherwise, solve the
                 eigenproblem for the smallest eigenvalues. Default is
                 `True`.

      method (str, optional): select the LOBPCG method. See the
                 description of the function above. Default is
                 `"ortho"`.

      niter (int, optional): maximum number of iterations. When
                 reached, the iteration process is hard-stopped and
                 the current approximation of eigenpairs is returned.
                 For infinite iteration until the convergence criteria
                 are met, use `-1`.

      tracker (callable, optional): a function for tracing the
                 iteration process. When specified, it is called at
                 each iteration step with the LOBPCG instance as an
                 argument. The LOBPCG instance holds the full state of
                 the iteration process in the following attributes:

                   `iparams`, `fparams`, `bparams` - dictionaries of
                   integer, float, and boolean valued input
                   parameters, respectively

                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
                   of integer, float, boolean, and Tensor valued
                   iteration variables, respectively.

                   `A`, `B`, `iK` - input Tensor arguments.

                   `E`, `X`, `S`, `R` - iteration Tensor variables.

                 For instance:

                   `ivars["istep"]` - the current iteration step
                   `X` - the current approximation of eigenvectors
                   `E` - the current approximation of eigenvalues
                   `R` - the current residual
                   `ivars["converged_count"]` - the current number of converged eigenpairs
                   `tvars["rerr"]` - the current state of the convergence criteria

                 Note that when `tracker` stores Tensor objects from
                 the LOBPCG instance, it must make copies of these.

                 If `tracker` sets `bvars["force_stop"] = True`, the
                 iteration process will be hard-stopped.

      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
                 various parameters to the LOBPCG algorithm when using
                 `method="ortho"`.

    Returns:

      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`

      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    References:

      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
      517-541. (25 pages)
      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124

      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
      Wu. (2002) A Block Orthogonalization Procedure with Constant
      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
      2165-2182. (18 pages)
      https://epubs.siam.org/doi/10.1137/S1064827500370883

      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
      https://epubs.siam.org/doi/abs/10.1137/17M1129830
    """
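    # A typical call (illustrative): find the 3 largest eigenpairs of a dense
    # SPD matrix; at convergence the residual is small relative to the scale of A:
    #
    #     >>> A0 = torch.randn(20, 20, dtype=torch.float64)
    #     >>> A = A0 @ A0.mT + 20 * torch.eye(20, dtype=torch.float64)
    #     >>> E, X = torch.lobpcg(A, k=3, largest=True)
    #     >>> (A @ X - X * E).norm() <= 1e-6 * A.norm()
    #     tensor(True)
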
    if not torch.jit.is_scripting():
        tensor_ops = (A, B, X, iK)
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                lobpcg,
                tensor_ops,
                A,
                k=k,
                B=B,
                X=X,
                n=n,
                iK=iK,
                niter=niter,
                tol=tol,
                largest=largest,
                method=method,
                tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams,
            )

    if not torch._jit_internal.is_scripting():
        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric, `A_grad` might not
            # be. Symmetrize the inputs here so that `A_grad`, when computed,
            # is symmetric; this matters for first-order optimization methods,
            # so that (A - alpha * A_grad) stays symmetric.
            A_sym = (A + A.mT) / 2
            B_sym = (B + B.mT) / 2 if (B is not None) else None

            return LOBPCGAutogradFunction.apply(
                A_sym, k, B_sym, X, n, iK, niter, tol, largest, method, tracker,
                ortho_iparams, ortho_fparams, ortho_bparams,
            )
    else:
        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Script and require grads is not supported atm. "
                "If you just want to do the forward, use .detach() "
                "on A and B before calling into lobpcg"
            )

    return _lobpcg(
        A, k, B, X, n, iK, niter, tol, largest, method, tracker,
        ortho_iparams, ortho_fparams, ortho_bparams,
    )

def _lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    device = A.device
    if tol is None:
        # standard value for lack of better data
        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
        tol = feps**0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    n = (k if n is None else n) if X is None else X.shape[-1]

    if m < 3 * n:
        raise ValueError(
            f"LOBPCG algorithm is not applicable when the number of A rows (={m})"
            f" is smaller than 3 x the number of requested eigenpairs (={n})"
        )

    method = "ortho" if method is None else method
    iparams = {
        "m": m,
        "n": n,
        "k": k,
        "niter": 1000 if niter is None else niter,
    }

    fparams = {
        "tol": tol,
    }

    bparams = {"largest": True if largest is None else largest}
    if method == "ortho":
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)
    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]

    if len(A.shape) > 2:
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = (
                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            )
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams["batch_index"] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        if not torch.jit.is_scripting():
            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)

    worker.run()

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]

    return worker.E[:k], worker.X[:, :k]

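# Batched input is flattened to (N, m, m), each problem is solved by its own
# LOBPCG worker, and the results are reshaped back (illustrative shapes):
#
#     >>> A0 = torch.randn(2, 12, 12, dtype=torch.float64)
#     >>> A = A0 @ A0.mT + 12 * torch.eye(12, dtype=torch.float64)
#     >>> E, X = torch.lobpcg(A, k=2)
#     >>> E.shape, X.shape
#     (torch.Size([2, 2]), torch.Size([2, 12, 2]))
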
class LOBPCG:
    """Worker class of LOBPCG methods."""

    def __init__(
        self,
        A: Optional[Tensor],
        B: Optional[Tensor],
        X: Tensor,
        iK: Optional[Tensor],
        iparams: Dict[str, int],
        fparams: Dict[str, float],
        bparams: Dict[str, bool],
        method: str,
        tracker: None,
    ) -> None:
        # constant parameters
        self.A = A
        self.B = B
        self.iK = iK
        self.iparams = iparams
        self.fparams = fparams
        self.bparams = bparams
        self.method = method
        self.tracker = tracker
        m = iparams["m"]
        n = iparams["n"]

        # variable parameters
        self.X = X
        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
        self.tvars: Dict[str, Tensor] = {}
        self.ivars: Dict[str, int] = {"istep": 0}
        self.fvars: Dict[str, float] = {"_": 0.0}
        self.bvars: Dict[str, bool] = {"_": False}
    def __str__(self):
        lines = ["LOBPCG:"]
        lines += [f"  iparams={self.iparams}"]
        lines += [f"  fparams={self.fparams}"]
        lines += [f"  bparams={self.bparams}"]
        lines += [f"  ivars={self.ivars}"]
        lines += [f"  fvars={self.fvars}"]
        lines += [f"  bvars={self.bvars}"]
        lines += [f"  tvars={self.tvars}"]
        lines += [f"  A={self.A}"]
        lines += [f"  B={self.B}"]
        lines += [f"  iK={self.iK}"]
        lines += [f"  X={self.X}"]
        lines += [f"  E={self.E}"]
        r = ""
        for line in lines:
            r += line + "\n"
        return r
    def update(self):
        """Set and update iteration variables."""
        if self.ivars["istep"] == 0:
            X_norm = float(torch.norm(self.X))
            iX_norm = X_norm**-1
            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
            self.fvars["X_norm"] = X_norm
            self.fvars["A_norm"] = A_norm
            self.fvars["B_norm"] = B_norm
            self.ivars["iterations_left"] = self.iparams["niter"]
            self.ivars["converged_count"] = 0
            self.ivars["converged_end"] = 0

        if self.method == "ortho":
            self._update_ortho()
        else:
            self._update_basic()

        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
        self.ivars["istep"] = self.ivars["istep"] + 1
    def update_residual(self):
        """Update the residual R from A, B, X, E."""
        mm = _utils.matmul
        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
    def update_converged_count(self):
        """Determine the number of converged eigenpairs using the backward stable
        convergence criterion, see discussion in Sec. 4.3 of [DuerschEtal2018].

        Users may redefine this method for custom convergence criteria.
        """
        # assumes that update_residual() has been called prior to this call
        prev_count = self.ivars["converged_count"]
        tol = self.fparams["tol"]
        A_norm = self.fvars["A_norm"]
        B_norm = self.fvars["B_norm"]
        E, X, R = self.E, self.X, self.R
        rerr = (
            torch.norm(R, 2, (0,))
            * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
        )
        converged = rerr.real < tol  # this is a norm, so the imaginary part is zero
        count = 0
        for b in converged:
            if not b:
                # ignore convergence of the following pairs to ensure
                # strict ordering of eigenpairs
                break
            count += 1
        assert count >= prev_count, (
            f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
        )
        self.ivars["converged_count"] = count
        self.tvars["rerr"] = rerr
        return count
    def stop_iteration(self):
        """Return True to stop iterations.

        Note that the tracker (if defined) can force-stop iterations by
        setting ``worker.bvars['force_stop'] = True``.
        """
        return (
            self.bvars.get("force_stop", False)
            or self.ivars["iterations_left"] == 0
            or self.ivars["converged_count"] >= self.iparams["k"]
        )
    def run(self):
        """Run LOBPCG iterations.

        Use this method as a template for implementing an LOBPCG
        iteration scheme with a custom tracker that is compatible with
        TorchScript.
        """
        self.update()

        if not torch.jit.is_scripting() and self.tracker is not None:
            self.call_tracker()

        while not self.stop_iteration():
            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:
                self.call_tracker()
    @torch.jit.unused
    def call_tracker(self):
        """Interface for tracking the iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.
        """
        # do nothing when in TorchScript mode
        pass
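    # A minimal tracker sketch (illustrative, eager mode only; `A` is any
    # symmetric positive definite matrix as in the `lobpcg` docstring). The
    # worker exposes its full state, so a tracker can log or force-stop:
    #
    #     >>> def my_tracker(worker):
    #     ...     # copy tensors before storing them, as documented above
    #     ...     print(worker.ivars["istep"], worker.tvars["rerr"].clone())
    #     >>> E, X = torch.lobpcg(A, k=2, tracker=my_tracker)  # doctest: +SKIP
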
    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X = mm(self.X, mm(Ri, Z))
            self.update_residual()
            np = 0
            nc = self.update_converged_count()
            self.S[:, :n] = self.X
            W = self._get_ortho(self.R, self.X)
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n : n + np] = P
            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in the Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
        C = (S^TBS) C E` to a standard eigenvalue problem :math:`(Ri^T
        S^TAS Ri) Z = Z E` where `C = Ri Z`.

        .. note:: In the original Rayleigh-Ritz procedure in
          [DuerschEtal2018], the problem is formulated as follows::

            SAS = S^T A S
            SBS = S^T B S
            D = (<diagonal matrix of SBS>) ** -1/2
            R^T R = Cholesky(D SBS D)
            Ri = D R^-1
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          To reduce the number of matrix products (denoted by empty
          space between matrices), here we introduce element-wise
          products (denoted by symbol `*`) so that the Rayleigh-Ritz
          procedure becomes::

            SAS = S^T A S
            SBS = S^T B S
            d = (<diagonal of SBS>) ** -1/2  # this is a 1-d column vector
            dd = d d^T                       # this is a 2-d matrix
            R^T R = Cholesky(dd * SBS)
            Ri = R^-1 * d                    # broadcasting
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          where `dd` is the 2-d matrix that replaces the matrix products
          `D M D` with one element-wise product `M * dd`; and `d` replaces
          the matrix product `D M` with the element-wise product `M * d`.
          Also, creating the diagonal matrix `D` is avoided.

        Args:
          S (Tensor): the matrix basis for the search subspace, size is
                      :math:`(m, n)`.

        Returns:
          Ri (tensor): upper-triangular transformation matrix of size
                       :math:`(n, n)`.
        """
        B = self.B
        SBS = _utils.qform(B, S)
        d_row = SBS.diagonal(0, -2, -1) ** -0.5
        d_col = d_row.reshape(d_row.shape[0], 1)
        # TODO use torch.linalg.cholesky_solve once it is implemented
        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
        return torch.linalg.solve_triangular(
            R, d_row.diag_embed(), upper=True, left=False
        )
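    # The property promised above can be verified directly (illustrative,
    # B = I so that S^T B S = S^T S): with d = diag(SBS) ** -0.5 and
    # R^T R = Cholesky(dd * SBS), the transform Ri reduces S^T B S to identity:
    #
    #     >>> S = torch.randn(10, 3, dtype=torch.float64)
    #     >>> SBS = S.mT @ S
    #     >>> d = SBS.diagonal(0, -2, -1) ** -0.5
    #     >>> R = torch.linalg.cholesky(SBS * d * d.reshape(-1, 1), upper=True)
    #     >>> Ri = torch.linalg.solve_triangular(R, d.diag_embed(), upper=True, left=False)
    #     >>> torch.allclose(Ri.mT @ SBS @ Ri, torch.eye(3, dtype=torch.float64))
    #     True
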
    def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopoulosEtal2002].

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to the `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1), where `n1 = n` if `drop` is `False`,
                       otherwise `n1 <= n`.
        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True, leading to a failure (notice the `d ** -0.5` operation
        # below). To prevent the failure, we drop the exact zero
        # columns here and then continue with the original algorithm.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)

        # The original Algorithm 4 from [DuerschPhD2015]:
        d_col = (d**-0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * d_col.mT
        E, Z = _utils.symeig(DUBUD)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * d_col.mT, Z * E**-0.5)
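    # Post-condition sketch (hypothetical `worker` instance, drop=False,
    # B = None so the B-inner product is the plain dot product): the returned
    # columns are B-orthonormal up to the tolerance `tau`:
    #
    #     >>> U1 = worker._get_svqb(torch.randn(worker.iparams["m"], 3), False,
    #     ...                       worker.fparams["ortho_tol_replace"])  # doctest: +SKIP
    #     >>> torch.allclose(_utils.qform(worker.B, U1), torch.eye(3))  # doctest: +SKIP
    #     True
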
    def _get_ortho(self, U, V):
        """Return B-orthonormal U whose columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopoulosEtal2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015].

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U = 0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False`, otherwise
                       `n1 <= n`.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams["m"]
        tau_ortho = self.fparams["ortho_tol"]
        tau_drop = self.fparams["ortho_tol_drop"]
        tau_replace = self.fparams["ortho_tol_replace"]
        i_max = self.iparams["ortho_i_max"]
        j_max = self.iparams["ortho_j_max"]
        # when use_drop == True, enable dropping U columns that have
        # a small contribution to the span([U, V])
        use_drop = self.bparams["ortho_use_drop"]

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
                self.fvars.pop(vkey)
        self.ivars.pop("ortho_i", 0)
        self.ivars.pop("ortho_j", 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        i = j = 0
        for i in range(i_max):
            U = U - mm(V, mm(V.mT, BU))
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars["ortho_i"] = i
                    self.ivars["ortho_j"] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(U.mT, BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
                R_norm = torch.norm(R)
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    break
            VBU = mm(V.mT, BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = f"ortho_VBU_rerr[{i}]"
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                break
            if m < U.shape[-1] + V.shape[-1]:
                # TorchScript needs the class variable assigned to a local
                B = self.B
                assert B is not None
                raise ValueError(
                    "Overdetermined shape of U:"
                    f" #B-cols(={B.shape[-1]}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
                )
        self.ivars["ortho_i"] = i
        self.ivars["ortho_j"] = j
        return U

# The tracker call is patched in dynamically (see `_lobpcg`) so that the
# LOBPCG class itself stays TorchScript-compatible.
LOBPCG_call_tracker_orig = LOBPCG.call_tracker


def LOBPCG_call_tracker(self):
    self.tracker(self)