pytorch

Форк
0
/
_lobpcg.py 
1157 строк · 42.5 Кб
1
# mypy: allow-untyped-defs
2
"""Locally Optimal Block Preconditioned Conjugate Gradient methods."""
3
# Author: Pearu Peterson
4
# Created: February 2020
5

6
from typing import Dict, Optional, Tuple
7

8
import torch
9
from torch import _linalg_utils as _utils, Tensor
10
from torch.overrides import handle_torch_function, has_torch_function
11

12

13
# Star-imports of this module expose only the `lobpcg` front-end.
__all__ = [
    "lobpcg",
]
14

15

16
def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
17
    # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
18
    F = D.unsqueeze(-2) - D.unsqueeze(-1)
19
    F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
20
    F.pow_(-1)
21

22
    # A.grad = U (D.grad + (U^T U.grad * F)) U^T
23
    Ut = U.mT.contiguous()
24
    res = torch.matmul(
25
        U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
26
    )
27

28
    return res
29

30

31
def _polynomial_coefficients_given_roots(roots):
32
    """
33
    Given the `roots` of a polynomial, find the polynomial's coefficients.
34

35
    If roots = (r_1, ..., r_n), then the method returns
36
    coefficients (a_0, a_1, ..., a_n (== 1)) so that
37
    p(x) = (x - r_1) * ... * (x - r_n)
38
         = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
39

40
    Note: for better performance requires writing a low-level kernel
41
    """
42
    poly_order = roots.shape[-1]
43
    poly_coeffs_shape = list(roots.shape)
44
    # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
45
    # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
46
    # but we insert one extra coefficient to enable better vectorization below
47
    poly_coeffs_shape[-1] += 2
48
    poly_coeffs = roots.new_zeros(poly_coeffs_shape)
49
    poly_coeffs[..., 0] = 1
50
    poly_coeffs[..., -1] = 1
51

52
    # perform the Horner's rule
53
    for i in range(1, poly_order + 1):
54
        # note that it is computationally hard to compute backward for this method,
55
        # because then given the coefficients it would require finding the roots and/or
56
        # calculating the sensitivity based on the Vieta's theorem.
57
        # So the code below tries to circumvent the explicit root finding by series
58
        # of operations on memory copies imitating the Horner's method.
59
        # The memory copies are required to construct nodes in the computational graph
60
        # by exploting the explicit (not in-place, separate node for each step)
61
        # recursion of the Horner's method.
62
        # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
63
        poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
64
        out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
65
        out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
66
            -1, poly_order - i + 1, i + 1
67
        )
68
        poly_coeffs = poly_coeffs_new
69

70
    return poly_coeffs.narrow(-1, 1, poly_order + 1)
71

72

73
def _polynomial_value(poly, x, zero_power, transition):
74
    """
75
    A generic method for computing poly(x) using the Horner's rule.
76

77
    Args:
78
      poly (Tensor): the (possibly batched) 1D Tensor representing
79
                     polynomial coefficients such that
80
                     poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
81
                     poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
82

83
      x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
84

85
      zero_power (Tensor): the representation of `x^0`. It is application-specific.
86

87
      transition (Callable): the function that accepts some intermediate result `int_val`,
88
                             the `x` and a specific polynomial coefficient
89
                             `poly[..., k]` for some iteration `k`.
90
                             It basically performs one iteration of the Horner's rule
91
                             defined as `x * int_val + poly[..., k] * zero_power`.
92
                             Note that `zero_power` is not a parameter,
93
                             because the step `+ poly[..., k] * zero_power` depends on `x`,
94
                             whether it is a vector, a matrix, or something else, so this
95
                             functionality is delegated to the user.
96
    """
97

98
    res = zero_power.clone()
99
    for k in range(poly.size(-1) - 2, -1, -1):
100
        res = transition(res, x, poly[..., k])
101
    return res
102

103

104
def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluate ``poly(x)`` for a (batched) square matrix ``x`` with Horner's rule.

    When ``zero_power`` is omitted, the identity matrix (with singleton batch
    dimensions, so it broadcasts against ``x``) is used as ``x^0``.
    See `_polynomial_value` for the meaning of ``poly``.
    """

    def _horner_step(acc, mat, coeff):
        # mat @ acc + coeff * I; the identity term is added in place on the
        # diagonal, with coeff's batch dims broadcast along it.
        nxt = mat.matmul(acc)
        nxt.diagonal(dim1=-2, dim2=-1).add_(coeff.unsqueeze(-1))
        return nxt

    if zero_power is None:
        size = x.size(-1)
        batch_ones = [1] * len(list(x.shape[:-2]))
        identity = torch.eye(size, size, dtype=x.dtype, device=x.device)
        zero_power = identity.view(*batch_ones, size, size)

    return _polynomial_value(poly, x, zero_power, _horner_step)
122

123

124
def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluate ``poly(x)`` element-wise for a (batched) vector ``x`` with Horner's rule.

    When ``zero_power`` is omitted, a tensor of ones shaped like ``x`` is used
    as ``x^0``. See `_polynomial_value` for the meaning of ``poly``.
    """

    def _horner_step(acc, vec, coeff):
        # coeff[..., None] + vec * acc, fused into one addcmul
        return torch.addcmul(coeff.unsqueeze(-1), vec, acc)

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)

    return _polynomial_value(poly, x, zero_power, _horner_step)
139

140

141
def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    """Gradient w.r.t. ``A`` of a partial symmetric eigendecomposition.

    Used when ``U`` is not square, i.e. it holds only ``k = D.size(-1)``
    eigenvectors of ``A`` with eigenvalues ``D``. The result combines the
    complete-eigenspace formula (the part in ``span(U)``) with the explicit
    solution of a Sylvester equation for the part in the orthogonal complement
    of ``span(U)``.

    Args:
      D_grad, U_grad: incoming gradients for ``D`` and ``U``.
      A: the symmetric matrix that was decomposed.
      D, U: eigenvalues ``[..., k]`` and eigenvectors ``[..., m, k]`` from the
        forward pass. ``U`` is assumed to have orthonormal columns (the
        projector ``I - U U^T`` below relies on it).
      largest: whether ``D`` holds the largest (True) or the smallest (False)
        eigenvalues; it fixes the definiteness sign needed for Cholesky.

    Returns:
      The gradient with respect to ``A``.

    Note:
      ``torch.linalg.cholesky`` may fail when the restricted characteristic
      polynomial is not definite, e.g. when ``A`` has repeated eigenvalues.
    """
    # Number of eigenpairs held by D / U. Computed explicitly rather than
    # reusing a loop variable, so it is also defined when k == 0.
    k = D.size(-1)

    # Projector onto the orthogonal complement of span(U): I - U U^T.
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # U_ortho: a basis of the orthogonal complement of span(U), obtained by
    # projecting a random [..., m, m - k] matrix with the projector above.
    # A dedicated generator is used for determinism.
    # NOTE(review): the generator is not seeded explicitly; this relies on a
    # freshly constructed torch.Generator starting from its default seed.
    gen = torch.Generator(A.device)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - k),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # Coefficients of the characteristic polynomial of diag(D); its roots are
    # exactly the entries of D.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # The code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #     p_res.zero_()
    #     for i in range(0, k):
    #         p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #     res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @  p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied mathematics letters 19.9 (2006): 859-864.
    #
    # The computation of `p_res` is regrouped by powers of A:
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # which avoids repeated matrix_power calls on A.
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for i in range(1, k + 1):
        poly_D = _vector_polynomial_value(chr_poly_D[..., i:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        # A^i U_grad is only needed by the next term; skip it after the last one.
        if i < k:
            U_grad_projected = A.matmul(U_grad_projected)

    # chr_poly_D(A), evaluated with Horner's rule instead of summing
    # chr_poly_D[j] * A.matrix_power(j).
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # The action of chr_poly_D(A) restricted to span(U_ortho).
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )
    # This matrix is inverted through a Cholesky factorization and
    # `torch.cholesky_solve` for stability, which needs a positive-definite
    # input. Assuming A has distinct eigenvalues, each eigenvalue lambda of A
    # outside D contributes prod_j (lambda - d_j). If D holds the largest
    # eigenvalues every factor is negative, giving the sign (-1)^k; otherwise
    # every factor is positive.
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # Gradient part in span(U).
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # Add the Sylvester-equation solution, which lives in span(U_ortho).
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res
246

247

248
def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    """Dispatch the backward pass of a symmetric eigendecomposition.

    A square ``U`` holds a complete eigenbasis, so the closed-form
    complete-eigenspace formula applies. A non-square ``U`` holds only some
    eigenvectors and needs the partial-eigenspace variant, which also depends
    on whether the largest or the smallest eigenpairs were computed.
    """
    is_complete = U.size(-2) == U.size(-1)
    if not is_complete:
        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
    return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
254

255

256
class LOBPCGAutogradFunction(torch.autograd.Function):
    """Autograd wrapper around `_lobpcg`.

    ``forward`` runs the LOBPCG solver. ``backward`` supports only dense,
    real ``A`` with ``B`` omitted (the standard eigenproblem); otherwise it
    raises ``ValueError``.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        # One gradient slot per forward input (A, k, B, X, n, iK, niter, tol,
        # largest, method, tracker, ortho_iparams, ortho_fparams,
        # ortho_bparams). Only A (index 0) and B (index 2) are tensors.
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest

        # lobpcg.backward has some limitations. Checks for unsupported input
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet. "
                "Note that lobpcg.forward does though."
            )
        if A.dtype in (torch.complex64, torch.complex128) or (
            B is not None and B.dtype in (torch.complex64, torch.complex128)
        ):
            raise ValueError(
                "lobpcg.backward does not support complex input yet. "
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        # Same default as the solver: the largest eigenpairs were computed.
        if largest is None:
            largest = True

        # B is None from here on: standard symmetric eigenproblem backward.
        if ctx.needs_input_grad[0]:
            grads[0] = _symeig_backward(D_grad, U_grad, A, D, U, largest)
        # grads[2] (B) stays None: only B == I (omitted) is supported.
        return tuple(grads)
343

344

345
def lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method, may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. note:: In general, the basic method spends least time per
      iteration. However, the robust methods converge much faster and
      are more stable. So, the usage of the basic method is generally
      not recommended but there exist cases where the usage of the
      basic method may be preferred.

    .. warning:: The backward method does not support sparse and complex inputs.
      It works only when `B` is not provided (i.e. `B == None`).
      We are actively working on extensions, and the details of
      the algorithms are going to be published promptly.

    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
      in first-order optimization routines, prior to running `lobpcg`
      we do the following symmetrization map: `A -> (A + A.mT) / 2`
      (and likewise for `B` when it is given).
      The map is performed only when `A` or `B` requires gradients.

    Args:

      A (Tensor): the input tensor of size :math:`(*, m, m)`

      B (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When not specified, `B` is interpreted as
                  identity matrix.

      X (Tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n` and `3 * n <= m`. When specified, it is
                  used as initial approximation of eigenvectors. X must
                  be a dense tensor.

      iK (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as preconditioner.

      k (integer, optional): the number of requested
                  eigenpairs. Default is the number of :math:`X`
                  columns (when specified) or `1`.

      n (integer, optional): if :math:`X` is not specified then `n`
                  specifies the size of the generated random
                  approximation of eigenvectors. Default value for `n`
                  is `k`. If :math:`X` is specified, `n` is ignored and
                  the number of :math:`X` columns is used instead.
                  `A` must have at least `3 * n` rows.

      tol (float, optional): residual tolerance for stopping
                 criterion. Default is `feps ** 0.5` where `feps` is
                 the machine epsilon of the floating-point data type
                 of the input tensor `A`.

      largest (bool, optional): when True, solve the eigenproblem for
                 the largest eigenvalues. Otherwise, solve the
                 eigenproblem for smallest eigenvalues. Default is
                 `True`.

      method (str, optional): select LOBPCG method. See the
                 description of the function above. Default is
                 "ortho".

      niter (int, optional): maximum number of iterations. When
                 reached, the iteration process is hard-stopped and
                 the current approximation of eigenpairs is returned.
                 For infinite iteration but until convergence criteria
                 is met, use `-1`. Default is `1000`.

      tracker (callable, optional) : a function for tracing the
                 iteration process. When specified, it is called at
                 each iteration step with LOBPCG instance as an
                 argument. The LOBPCG instance holds the full state of
                 the iteration process in the following attributes:

                   `iparams`, `fparams`, `bparams` - dictionaries of
                   integer, float, and boolean valued input
                   parameters, respectively

                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
                   of integer, float, boolean, and Tensor valued
                   iteration variables, respectively.

                   `A`, `B`, `iK` - input Tensor arguments.

                   `E`, `X`, `S`, `R` - iteration Tensor variables.

                 For instance:

                   `ivars["istep"]` - the current iteration step
                   `X` - the current approximation of eigenvectors
                   `E` - the current approximation of eigenvalues
                   `R` - the current residual
                   `ivars["converged_count"]` - the current number of converged eigenpairs
                   `tvars["rerr"]` - the current state of convergence criteria

                 Note that when `tracker` stores Tensor objects from
                 the LOBPCG instance, it must make copies of these.

                 If `tracker` sets `bvars["force_stop"] = True`, the
                 iteration process will be hard-stopped.

      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
                 various parameters to LOBPCG algorithm when using
                 `method="ortho"`.

    Returns:

      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`

      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    References:

      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
      517-541. (25 pages)
      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124

      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
      Wu. (2002) A Block Orthogonalization Procedure with Constant
      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
      2165-2182. (18 pages)
      https://epubs.siam.org/doi/10.1137/S1064827500370883

      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
      https://epubs.siam.org/doi/abs/10.1137/17M1129830

    """

    if not torch.jit.is_scripting():
        # Dispatch to __torch_function__ overrides of Tensor-like arguments.
        tensor_ops = (A, B, X, iK)
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                lobpcg,
                tensor_ops,
                A,
                k=k,
                B=B,
                X=X,
                n=n,
                iK=iK,
                niter=niter,
                tol=tol,
                largest=largest,
                method=method,
                tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams,
            )

    if not torch.jit.is_scripting():
        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric,
            # the `A_grad` might be not. Therefore we perform the trick below,
            # so that `A_grad` becomes symmetric.
            # The symmetrization is important for first-order optimization methods,
            # so that (A - alpha * A_grad) is still a symmetric matrix.
            # Same holds for `B`.
            A_sym = (A + A.mT) / 2
            B_sym = (B + B.mT) / 2 if (B is not None) else None

            return LOBPCGAutogradFunction.apply(
                A_sym,
                k,
                B_sym,
                X,
                n,
                iK,
                niter,
                tol,
                largest,
                method,
                tracker,
                ortho_iparams,
                ortho_fparams,
                ortho_bparams,
            )
    else:
        # Autograd is not available under TorchScript.
        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Script and require grads is not supported atm. "
                "If you just want to do the forward, use .detach() "
                "on A and B before calling into lobpcg"
            )

    return _lobpcg(
        A,
        k,
        B,
        X,
        n,
        iK,
        niter,
        tol,
        largest,
        method,
        tracker,
        ortho_iparams,
        ortho_fparams,
        ortho_bparams,
    )
580

581

582
def _lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Validate inputs, assemble solver parameters and run LOBPCG.

    See :func:`lobpcg` for the meaning of the arguments. This function does no
    autograd handling and no ``__torch_function__`` dispatch. It fills in the
    defaults (``k``, ``n``, ``tol``, ``niter=1000``, ``largest=True``,
    ``method="ortho"`` and the ``ortho_*`` parameters) and solves every matrix
    of a batched ``A`` independently.

    Raises:
      AssertionError: if ``A`` is not square or ``B`` differs in shape from ``A``.
      ValueError: if ``A`` has fewer than ``3 * n`` rows.
    """
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    if tol is None:
        # Machine epsilon of float32 / float64; the default tolerance is its square root.
        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
        tol = feps**0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    # A given X fixes the basis size n; otherwise n defaults to k.
    n = (k if n is None else n) if X is None else X.shape[-1]

    if m < 3 * n:
        raise ValueError(
            f"LOBPCG algorithm is not applicable when the number of A rows (={m})"
            f" is smaller than 3 x the number of requested eigenpairs (={n})"
        )

    method = "ortho" if method is None else method

    iparams = {
        "m": m,
        "n": n,
        "k": k,
        "niter": 1000 if niter is None else niter,
    }

    fparams = {
        "tol": tol,
    }

    bparams = {"largest": True if largest is None else largest}

    if method == "ortho":
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)

    if not torch.jit.is_scripting():
        # Eager mode: LOBPCG.call_tracker is patched for the duration of the
        # solve. The helper restores it even when the solver raises. It is
        # reached only from this branch, which TorchScript skips, so its
        # try/finally is never compiled.
        return _lobpcg_run_with_tracker_hook(
            A, B, X, iK, iparams, fparams, bparams, method, tracker, k
        )
    return _lobpcg_run(A, B, X, iK, iparams, fparams, bparams, method, tracker, k)


def _lobpcg_run(
    A: Tensor,
    B: Optional[Tensor],
    X: Optional[Tensor],
    iK: Optional[Tensor],
    iparams: Dict[str, int],
    fparams: Dict[str, float],
    bparams: Dict[str, bool],
    method: str,
    tracker: None,
    k: int,
) -> Tuple[Tensor, Tensor]:
    """Run LOBPCG worker(s) on validated inputs and return the first ``k`` eigenpairs.

    A batched ``A`` (more than 2 dims) is flattened to ``N`` matrices that are
    solved one at a time; a missing ``X`` is replaced by a random ``(m, n)``
    initial guess per matrix.
    """
    dtype = _utils.get_floating_dtype(A)
    device = A.device
    m = iparams["m"]
    n = iparams["n"]

    if len(A.shape) > 2:
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = (
                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            )
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams["batch_index"] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
    worker.run()
    return worker.E[:k], worker.X[:, :k]


def _lobpcg_run_with_tracker_hook(
    A, B, X, iK, iparams, fparams, bparams, method, tracker, k
):
    """Eager-only wrapper around `_lobpcg_run`.

    Installs ``LOBPCG_call_tracker`` as ``LOBPCG.call_tracker`` for the
    duration of the run and always restores ``LOBPCG_call_tracker_orig``
    afterwards, including when the solver raises.
    """
    LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[method-assign]
    try:
        return _lobpcg_run(A, B, X, iK, iparams, fparams, bparams, method, tracker, k)
    finally:
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[method-assign]
689

690

691
class LOBPCG:
    """Worker class of LOBPCG methods.

    An instance holds one (generalized) eigenvalue problem together with the
    state of its LOBPCG iteration. :meth:`run` iterates until
    :meth:`stop_iteration` returns True; the current approximations are kept
    in ``self.E`` (eigenvalues) and ``self.X`` (eigenvectors, as columns).

    Parameter dictionaries read by the methods below:

    * ``iparams``: ``m`` (rows of ``X``), ``n`` (columns of ``X``), ``k``
      (number of eigenpairs that must converge before stopping), ``niter``
      (iteration budget), ``ortho_i_max`` / ``ortho_j_max`` (loop limits of
      :meth:`_get_ortho`).
    * ``fparams``: ``tol`` (convergence tolerance), ``ortho_tol``,
      ``ortho_tol_drop``, ``ortho_tol_replace`` (tolerances of
      :meth:`_get_ortho`).
    * ``bparams``: ``largest`` (forwarded to ``_utils.symeig``) and
      ``ortho_use_drop``.

    ``method == "ortho"`` selects :meth:`_update_ortho`; any other value
    selects :meth:`_update_basic`.
    """

    def __init__(
        self,
        A: Optional[Tensor],
        B: Optional[Tensor],
        X: Tensor,
        iK: Optional[Tensor],
        iparams: Dict[str, int],
        fparams: Dict[str, float],
        bparams: Dict[str, bool],
        method: str,
        tracker: None,
    ) -> None:
        # constant parameters
        self.A = A
        self.B = B
        self.iK = iK
        self.iparams = iparams
        self.fparams = fparams
        self.bparams = bparams
        self.method = method
        self.tracker = tracker
        m = iparams["m"]
        n = iparams["n"]

        # variable parameters
        # X: current approximate eigenvectors (columns), shape (m, n).
        self.X = X
        # E: current approximate eigenvalues, one per column of X.
        self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
        # R: residuals A X - B X diag(E), see update_residual.
        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
        # S: search-subspace basis with room for three (m, n) panels; the
        # update methods fill it as [X | P | W].
        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
        # Typed iteration-state dictionaries; trackers receive the worker and
        # may read (or, for bvars["force_stop"], write) them.
        self.tvars: Dict[str, Tensor] = {}
        self.ivars: Dict[str, int] = {"istep": 0}
        self.fvars: Dict[str, float] = {"_": 0.0}
        self.bvars: Dict[str, bool] = {"_": False}

    def __str__(self):
        lines = ["LOBPCG:"]
        lines += [f"  iparams={self.iparams}"]
        lines += [f"  fparams={self.fparams}"]
        lines += [f"  bparams={self.bparams}"]
        lines += [f"  ivars={self.ivars}"]
        lines += [f"  fvars={self.fvars}"]
        lines += [f"  bvars={self.bvars}"]
        lines += [f"  tvars={self.tvars}"]
        lines += [f"  A={self.A}"]
        lines += [f"  B={self.B}"]
        lines += [f"  iK={self.iK}"]
        lines += [f"  X={self.X}"]
        lines += [f"  E={self.E}"]
        r = ""
        for line in lines:
            r += line + "\n"
        return r

    def update(self):
        """Set and update iteration variables.

        On the first call (``istep == 0``) this records the norm estimates
        ``X_norm = ||X||``, ``A_norm = ||A X|| / ||X||`` and
        ``B_norm = ||B X|| / ||X||`` (Frobenius norms) in ``fvars`` and
        initializes the iteration counters. It then performs one step of the
        selected method and advances ``istep`` / ``iterations_left``.
        """
        if self.ivars["istep"] == 0:
            X_norm = float(torch.norm(self.X))
            iX_norm = X_norm**-1
            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
            self.fvars["X_norm"] = X_norm
            self.fvars["A_norm"] = A_norm
            self.fvars["B_norm"] = B_norm
            self.ivars["iterations_left"] = self.iparams["niter"]
            self.ivars["converged_count"] = 0
            self.ivars["converged_end"] = 0

        if self.method == "ortho":
            self._update_ortho()
        else:
            self._update_basic()

        self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
        self.ivars["istep"] = self.ivars["istep"] + 1

    def update_residual(self):
        """Update residual R from A, B, X, E.

        Computes ``R = A X - (B X) * E``; multiplying by the 1-d ``E``
        broadcasts over rows, i.e. scales column i of ``B X`` by ``E[i]``.
        """
        mm = _utils.matmul
        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E

    def update_converged_count(self):
        """Determine the number of converged eigenpairs using backward stable
        convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].

        Only a leading run of converged pairs is counted, so that eigenpairs
        are accepted in order. The per-column relative errors are stored in
        ``tvars["rerr"]``.

        Users may redefine this method for custom convergence criteria.

        Returns:
            int: the new converged count, which must not be smaller than the
            previous one (checked by an assertion).
        """
        # (...) -> int
        prev_count = self.ivars["converged_count"]
        tol = self.fparams["tol"]
        A_norm = self.fvars["A_norm"]
        B_norm = self.fvars["B_norm"]
        E, X, R = self.E, self.X, self.R
        # Relative backward error of each column:
        #   ||r_i|| / (||x_i|| * (||A|| + |lambda_i| * ||B||)).
        # Using |lambda_i| keeps the denominator positive: with the signed
        # value, a negative eigenvalue estimate could make rerr negative and
        # the pair would be reported as converged whatever its residual.
        rerr = (
            torch.norm(R, 2, (0,))
            * (
                torch.norm(X, 2, (0,))
                * (A_norm + E[: X.shape[-1]].abs() * B_norm)
            )
            ** -1
        )
        converged = rerr.real < tol  # this is a norm so imag is 0.0
        count = 0
        for b in converged:
            if not b:
                # ignore convergence of following pairs to ensure
                # strict ordering of eigenpairs
                break
            count += 1
        assert (
            count >= prev_count
        ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease"
        self.ivars["converged_count"] = count
        self.tvars["rerr"] = rerr
        return count

    def stop_iteration(self):
        """Return True to stop iterations.

        Iterations stop when the tracker sets ``bvars['force_stop']``, when
        the ``niter`` budget is exhausted, or when at least ``k`` eigenpairs
        have converged.

        Note that tracker (if defined) can force-stop iterations by
        setting ``worker.bvars['force_stop'] = True``.
        """
        return (
            self.bvars.get("force_stop", False)
            or self.ivars["iterations_left"] == 0
            or self.ivars["converged_count"] >= self.iparams["k"]
        )

    def run(self):
        """Run LOBPCG iterations.

        Use this method as a template for implementing LOBPCG
        iteration scheme with custom tracker that is compatible with
        TorchScript.
        """
        self.update()

        if not torch.jit.is_scripting() and self.tracker is not None:
            self.call_tracker()

        while not self.stop_iteration():
            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:
                self.call_tracker()

    @torch.jit.unused
    def call_tracker(self):
        """Interface for tracking iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.
        """
        # do nothing when in TorchScript mode

    # Internal methods

    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.

        The search basis ``S`` is not kept B-orthonormal by this method, so
        every Rayleigh-Ritz step goes through
        :meth:`_get_rayleigh_ritz_transform`.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            # Rayleigh-Ritz on span(X).
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            # In-place write: this also overwrites the tensor that was passed
            # to __init__ as X.
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0  # there is no P panel on the first step
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            # W: residuals multiplied by iK (via _utils.matmul).
            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            # Rayleigh-Ritz on the not-yet-converged part of S.
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            # Ritz coefficient columns n .. 2n-nc define the new directions P.
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            # Only residuals of unconverged pairs extend the basis.
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.

        This method builds ``S`` from B-orthonormalized panels (via
        :meth:`_get_ortho`), so after the first step the Rayleigh-Ritz
        problem is solved directly on ``S^T A S`` without a transform.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]
        nc = self.ivars["converged_count"]
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            # Rayleigh-Ritz on span(X).
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X = mm(self.X, mm(Ri, Z))
            # Store the Ritz values before computing the residual and the
            # convergence test, both of which read self.E (as _update_basic
            # does); otherwise they would be evaluated with E == 0.
            self.E[:] = E
            self.update_residual()
            np = 0  # there is no P panel on the first step
            nc = self.update_converged_count()
            self.S[:, :n] = self.X
            W = self._get_ortho(self.R, self.X)
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
            self.E[nc:] = E_[: n - nc]
            P = mm(S_, mm(Z[:, n - nc :], _utils.basis(Z[: n - nc, n - nc :].mT)))
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n : n + np] = P
            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
        C = (S^TBS) C E` to a standard eigenvalue problem :math:`(Ri^T
        S^TAS Ri) Z = Z E` where `C = Ri Z`.

        .. note:: In the original Rayleigh-Ritz procedure in
          [DuerschEtal2018], the problem is formulated as follows::

            SAS = S^T A S
            SBS = S^T B S
            D = (<diagonal matrix of SBS>) ** -1/2
            R^T R = Cholesky(D SBS D)
            Ri = D R^-1
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          To reduce the number of matrix products (denoted by empty
          space between matrices), here we introduce element-wise
          products (denoted by symbol `*`) so that the Rayleigh-Ritz
          procedure becomes::

            SAS = S^T A S
            SBS = S^T B S
            d = (<diagonal of SBS>) ** -1/2    # this is 1-d column vector
            dd = d d^T                         # this is 2-d matrix
            R^T R = Cholesky(dd * SBS)
            Ri = R^-1 * d                      # broadcasting
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          where `dd` is 2-d matrix that replaces matrix products `D M
          D` with one element-wise product `M * dd`; and `d` replaces
          matrix product `D M` with element-wise product `M *
          d`. Also, creating the diagonal matrix `D` is avoided.

        Args:
        S (Tensor): the matrix basis for the search subspace, size is
                    :math:`(m, n)`.

        Returns:
        Ri (tensor): upper-triangular transformation matrix of size
                     :math:`(n, n)`.

        """
        B = self.B
        mm = torch.matmul
        SBS = _utils.qform(B, S)
        d_row = SBS.diagonal(0, -2, -1) ** -0.5
        d_col = d_row.reshape(d_row.shape[0], 1)
        # TODO use torch.linalg.cholesky_solve once it is implemented
        R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
        # Solves Ri R = diag(d) for Ri, i.e. Ri = D R^-1.
        return torch.linalg.solve_triangular(
            R, d_row.diag_embed(), upper=True, left=False
        )

    def _get_svqb(self, U: Tensor, drop: bool, tau: float) -> Tensor:
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on the
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopolousWu2002].

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to the `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1) with `n1 <= n`. Columns with exactly
                       zero B-norm are always dropped; when `drop` is
                       `True`, columns whose scaled eigenvalue is not
                       above `tau * max(abs(E))` are dropped as well.

        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True leading to a failure (notice the `d ** -0.5` operation
        # in the original algorithm). To prevent the failure, we drop
        # the exact zero columns here and then continue with the
        # original algorithm below.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)

        # The original algorithm 4 from [DuerschPhD2015].
        d_col = (d**-0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * d_col.mT
        E, Z = _utils.symeig(DUBUD)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            # Clamp small eigenvalues from below so that E ** -0.5 stays finite.
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * d_col.mT, Z * E**-0.5)

    def _get_ortho(self, U, V):
        """Return B-orthonormal U whose columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on the Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopolousWu2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015]

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U=0`, size is (m, n1)
                       with `n1 <= n` (columns may be dropped by
                       `_get_svqb`).

        Raises:

          ValueError : if an orthogonalization sweep does not reach
                       `ortho_tol` and `m < n1 + k`.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams["m"]
        tau_ortho = self.fparams["ortho_tol"]
        tau_drop = self.fparams["ortho_tol_drop"]
        tau_replace = self.fparams["ortho_tol_replace"]
        i_max = self.iparams["ortho_i_max"]
        j_max = self.iparams["ortho_j_max"]
        # when use_drop==True, enable dropping U columns that have
        # small contribution to the `span([U, V])`.
        use_drop = self.bparams["ortho_use_drop"]

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
                self.fvars.pop(vkey)
        self.ivars.pop("ortho_i", 0)
        self.ivars.pop("ortho_j", 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        VBU = mm(V.mT, BU)
        i = j = 0
        stats = ""
        for i in range(i_max):
            # Remove the B-projection of U onto V.
            U = U - mm(V, VBU)
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars["ortho_i"] = i
                    self.ivars["ortho_j"] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(U.mT, BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
                R_norm = torch.norm(R)
                # https://github.com/pytorch/pytorch/issues/33810 workaround:
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = f"ortho_UBUmI_rerr[{i}, {j}]"
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    break
            VBU = mm(V.mT, BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = f"ortho_VBU_rerr[{i}]"
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                break
            if m < U.shape[-1] + V.shape[-1]:
                # Report m (the row count of X, which equals the size of B,
                # see update_residual) instead of reading B.shape, so this
                # ValueError is also raised when B is None (identity).
                raise ValueError(
                    "Overdetermined shape of U:"
                    f" #B-cols(={m}) >= #U-cols(={U.shape[-1]}) + #V-cols(={V.shape[-1]}) must hold"
                )
        self.ivars["ortho_i"] = i
        self.ivars["ortho_j"] = j
        return U
1149

1150

1151
# The tracker hook is defined outside the LOBPCG class because TorchScript
# cannot compile calls to user-supplied callbacks. Keep a handle on the
# class's own no-op ``call_tracker`` so that it can be restored after a run.
LOBPCG_call_tracker_orig = LOBPCG.call_tracker


def LOBPCG_call_tracker(self):
    """Forward the worker ``self`` to its user-supplied ``tracker`` callable.

    ``lobpcg`` installs this function as ``LOBPCG.call_tracker`` when it is
    not running under TorchScript.
    """
    callback = self.tracker
    callback(self)
1158

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.