pytorch-lightning

Форк
0
36 строк · 1.2 Кб
1
import collections
2
import dataclasses
3

4
import torch
5
from lightning.fabric.utilities.optimizer import _optimizer_to_device
6
from torch import Tensor
7

8

9
def test_optimizer_to_device():
10
    @dataclasses.dataclass(frozen=True)
11
    class FooState:
12
        bar: int
13

14
    class TestOptimizer(torch.optim.SGD):
15
        def __init__(self, *args, **kwargs):
16
            super().__init__(*args, **kwargs)
17
            self.state["dummy"] = torch.tensor(0)
18
            self.state["frozen"] = FooState(0)
19

20
    layer = torch.nn.Linear(32, 2)
21
    opt = TestOptimizer(layer.parameters(), lr=0.1)
22
    _optimizer_to_device(opt, "cpu")
23
    if torch.cuda.is_available():
24
        _optimizer_to_device(opt, "cuda")
25
        assert_opt_parameters_on_device(opt, "cuda")
26

27

28
def assert_opt_parameters_on_device(opt, device: str):
29
    for param in opt.state.values():
30
        # Not sure there are any global tensors in the state dict
31
        if isinstance(param, Tensor):
32
            assert param.data.device.type == device
33
        elif isinstance(param, collections.abc.Mapping):
34
            for subparam in param.values():
35
                if isinstance(subparam, Tensor):
36
                    assert param.data.device.type == device
37

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.