pytorch-lightning
36 строк · 1.2 Кб
1import collections2import dataclasses3
4import torch5from lightning.fabric.utilities.optimizer import _optimizer_to_device6from torch import Tensor7
8
def test_optimizer_to_device():
    """Check that ``_optimizer_to_device`` moves custom optimizer state between devices.

    The optimizer's ``state`` is seeded with entries that are not keyed by a parameter:
    a bare tensor and a frozen dataclass. After each move, the tensors must be on the
    target device and the frozen dataclass entry must still compare equal to its
    original value.
    """

    @dataclasses.dataclass(frozen=True)
    class FooState:
        bar: int

    class TestOptimizer(torch.optim.SGD):
        # SGD subclass that adds a plain tensor and a frozen (immutable) dataclass to
        # ``state`` under string keys, alongside the usual per-parameter entries.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.state["dummy"] = torch.tensor(0)
            self.state["frozen"] = FooState(0)

    layer = torch.nn.Linear(32, 2)
    opt = TestOptimizer(layer.parameters(), lr=0.1)

    # Always verify the CPU move so the test asserts something even without a GPU.
    _optimizer_to_device(opt, "cpu")
    assert_opt_parameters_on_device(opt, "cpu")
    assert opt.state["frozen"] == FooState(0)

    if torch.cuda.is_available():
        _optimizer_to_device(opt, "cuda")
        assert_opt_parameters_on_device(opt, "cuda")
        assert opt.state["frozen"] == FooState(0)
27
def assert_opt_parameters_on_device(opt, device: str):
    """Assert that every tensor held in ``opt.state`` is on a device of type ``device``.

    Two levels of ``opt.state`` are inspected:

    * values that are tensors themselves;
    * tensors stored as values of mapping entries, such as per-parameter state dicts.

    Deeper nesting and non-tensor values (e.g. ints, dataclasses) are not inspected.

    Args:
        opt: An optimizer, or any object exposing a ``state`` mapping.
        device: Device type compared against ``tensor.device.type`` (e.g. ``"cpu"`` or
            ``"cuda"``). A string with an index such as ``"cuda:0"`` will not match.

    Raises:
        AssertionError: If any inspected tensor is on a different device type.
    """
    # ``import collections`` alone does not guarantee ``collections.abc`` is loaded.
    from collections.abc import Mapping

    for state in opt.state.values():
        # Tensors can be stored directly in ``state`` (not inside a per-parameter dict).
        if isinstance(state, Tensor):
            assert state.data.device.type == device
        elif isinstance(state, Mapping):
            for value in state.values():
                # Check the nested tensor itself, not the enclosing mapping.
                if isinstance(value, Tensor):
                    assert value.data.device.type == device