pytorch-lightning
257 lines · 7.6 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import Any, Dict, Iterator, List, Optional, Tuple15
16import torch17import torch.nn as nn18import torch.nn.functional as F19from torch import Tensor20from torch.optim import Optimizer21from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset22
23from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER24from lightning.pytorch import LightningDataModule, LightningModule25from lightning.pytorch.core.optimizer import LightningOptimizer26from lightning.pytorch.utilities.types import STEP_OUTPUT27
28
class RandomDictDataset(Dataset):
    """Map-style dataset yielding a dict of two related random tensors per index.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        # One random row of `size` features per sample.
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        sample = self.data[index]
        # "b" is a deterministic offset of "a" so tests can check the relation.
        return {"a": sample, "b": sample + 2}

    def __len__(self) -> int:
        return self.len
46
class RandomDataset(Dataset):
    """Map-style dataset of `length` random vectors with `size` features each.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> Tensor:
        # Plain row lookup into the pre-generated tensor.
        return self.data[index]

    def __len__(self) -> int:
        return self.len
62
class RandomIterableDataset(IterableDataset):
    """Iterable dataset producing `count` fresh random vectors of `size` features.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        # Generate samples on the fly; nothing is stored between iterations.
        remaining = self.count
        while remaining > 0:
            remaining -= 1
            yield torch.randn(self.size)
76
class RandomIterableDatasetWithLen(IterableDataset):
    """Like ``RandomIterableDataset`` but also reports a length via ``__len__``.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        # Yield exactly as many samples as the advertised length.
        for _ in range(len(self)):
            yield torch.randn(self.size)

    def __len__(self) -> int:
        return self.count
93
class BoringModel(LightningModule):
    """Testing PL Module.

    Use as follows:
    - subclass
    - modify the behavior for what you want

    .. warning:: This is meant for testing/debugging and is experimental.

    Example::

        class TestModel(BoringModel):
            def training_step(self, ...):
                ...  # do your own thing

    """

    def __init__(self) -> None:
        super().__init__()
        # A single tiny layer: just enough parameters for optimizers to update.
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.layer(x)

    def loss(self, preds: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        target = torch.ones_like(preds) if labels is None else labels
        return torch.nn.functional.mse_loss(preds, target)

    def step(self, batch: Any) -> Tensor:
        return self.loss(self(batch))

    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        return {"loss": self.step(batch)}

    def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        return {"x": self.step(batch)}

    def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        return {"y": self.step(batch)}

    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
        sgd = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(sgd, step_size=1)
        return [sgd], [scheduler]

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))
154
class BoringDataModule(LightningDataModule):
    """DataModule carving four equal 64-sample subsets out of one random pool.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        # One pool of 256 samples; `setup` hands out stage-specific slices of it.
        self.random_full = RandomDataset(32, 64 * 4)

    def _subset(self, part: int) -> Subset:
        # Partition `part` covers indices [part * 64, (part + 1) * 64).
        return Subset(self.random_full, indices=range(part * 64, (part + 1) * 64))

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.random_train = self._subset(0)

        if stage in ("fit", "validate"):
            self.random_val = self._subset(1)

        if stage == "test":
            self.random_test = self._subset(2)

        if stage == "predict":
            self.random_predict = self._subset(3)

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.random_train)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.random_val)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.random_test)

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(self.random_predict)
189
class ManualOptimBoringModel(BoringModel):
    """``BoringModel`` variant that drives optimization by hand.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        # Opt out of Lightning's automatic optimization; training_step does it all.
        self.automatic_optimization = False

    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        optimizer = self.optimizers()
        assert isinstance(optimizer, (Optimizer, LightningOptimizer))
        loss = self.step(batch)
        # Manual optimization loop: clear grads, backprop, then step.
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss
208
class DemoModel(LightningModule):
    """Minimal demo LightningModule: one linear layer, sum-of-activations loss.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
        super().__init__()
        self.l1 = torch.nn.Linear(32, out_dim)
        self.learning_rate = learning_rate

    def forward(self, x: Tensor) -> Tensor:
        # Flatten everything past the batch dimension before the linear layer.
        flat = x.view(x.size(0), -1)
        return torch.relu(self.l1(flat))

    def training_step(self, batch: Any, batch_nb: int) -> STEP_OUTPUT:
        activations = self(batch)
        # Arbitrary scalar objective; only needs to be differentiable.
        return activations.sum()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
230
class Net(nn.Module):
    """Small MNIST-style CNN: two convs, max-pool, two dropouts, two linears.

    .. warning:: This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 spatial, assuming 28x28 single-channel input.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: Tensor) -> Tensor:
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        hidden = self.dropout1(F.max_pool2d(hidden, 2))
        hidden = torch.flatten(hidden, 1)
        hidden = self.dropout2(F.relu(self.fc1(hidden)))
        # Log-probabilities over the 10 classes.
        return F.log_softmax(self.fc2(hidden), dim=1)