# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.utilities.types import STEP_OUTPUT


class RandomDictDataset(Dataset):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        a = self.data[index]
        b = a + 2
        return {"a": a, "b": b}

    def __len__(self) -> int:
        return self.len


class RandomDataset(Dataset):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> Tensor:
        return self.data[index]

    def __len__(self) -> int:
        return self.len


class RandomIterableDataset(IterableDataset):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        for _ in range(self.count):
            yield torch.randn(self.size)


class RandomIterableDatasetWithLen(IterableDataset):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        for _ in range(len(self)):
            yield torch.randn(self.size)

    def __len__(self) -> int:
        return self.count
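

# A minimal sketch of why the two iterable variants exist: the only difference
# is ``__len__``. A DataLoader wrapping ``RandomIterableDatasetWithLen`` has a
# length, while one wrapping the plain ``RandomIterableDataset`` does not,
# which matters for code paths that need the batch count up front. The helper
# name below is hypothetical, for illustration only:
def _example_iterable_lengths() -> None:
    loader = DataLoader(RandomIterableDatasetWithLen(32, count=8))
    assert len(loader) == 8  # DataLoader delegates to the dataset's __len__

    loader = DataLoader(RandomIterableDataset(32, count=8))
    try:
        len(loader)  # the dataset defines no __len__, so this raises
    except TypeError:
        pass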


class BoringModel(LightningModule):
    """Testing PL Module.

    Use as follows:
    - subclass
    - modify the behavior for what you want

    .. warning::  This is meant for testing/debugging and is experimental.

    Example::

        class TestModel(BoringModel):
            def training_step(self, ...):
                ...  # do your own thing

    """

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.layer(x)

    def loss(self, preds: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        if labels is None:
            labels = torch.ones_like(preds)
        # An arbitrary loss so that `Trainer.fit` calls actually update the model weights
        return torch.nn.functional.mse_loss(preds, labels)

    def step(self, batch: Any) -> Tensor:
        output = self(batch)
        return self.loss(output)

    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        return {"loss": self.step(batch)}

    def validation_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        return {"x": self.step(batch)}

    def test_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        return {"y": self.step(batch)}

    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))
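

# A minimal end-to-end sketch of the pattern the docstring above describes:
# subclass, override only the hooks you care about, and fit. The subclass and
# helper names here are hypothetical, for illustration only:
def _example_boring_model_run() -> None:
    from lightning.pytorch import Trainer

    class TrainLossModel(BoringModel):
        def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
            loss = self.step(batch)
            self.log("train_loss", loss)  # do your own thing, e.g. log the loss
            return {"loss": loss}

    trainer = Trainer(fast_dev_run=True)  # one batch per loop, as a smoke test
    trainer.fit(TrainLossModel())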


class BoringDataModule(LightningDataModule):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        self.random_full = RandomDataset(32, 64 * 4)

    def setup(self, stage: str) -> None:
        if stage == "fit":
            self.random_train = Subset(self.random_full, indices=range(64))

        if stage in ("fit", "validate"):
            self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))

        if stage == "test":
            self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))

        if stage == "predict":
            self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.random_train)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.random_val)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.random_test)

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(self.random_predict)
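

# ``setup`` above slices one 256-sample pool into four disjoint 64-sample
# splits, one per stage; the Trainer calls ``setup`` with the right stage
# automatically. A minimal sketch wiring the datamodule into a fit run
# (hypothetical helper, for illustration only):
def _example_boring_datamodule_run() -> None:
    from lightning.pytorch import Trainer

    trainer = Trainer(fast_dev_run=True)
    trainer.fit(BoringModel(), datamodule=BoringDataModule())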


class ManualOptimBoringModel(BoringModel):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
        opt = self.optimizers()
        assert isinstance(opt, (Optimizer, LightningOptimizer))
        loss = self.step(batch)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()
        return loss
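

# With ``automatic_optimization = False``, Lightning also stops calling
# ``lr_scheduler.step()`` for you, and this class still inherits a StepLR
# scheduler from ``BoringModel.configure_optimizers``. A minimal sketch of
# stepping it by hand, assuming the standard ``lr_schedulers()`` accessor
# (hypothetical subclass, for illustration only):
class _ManualOptimWithSchedulerStep(ManualOptimBoringModel):
    def on_train_epoch_end(self) -> None:
        sch = self.lr_schedulers()  # a single StepLR here, so not a list
        sch.step()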


class DemoModel(LightningModule):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
        super().__init__()
        self.l1 = torch.nn.Linear(32, out_dim)
        self.learning_rate = learning_rate

    def forward(self, x: Tensor) -> Tensor:
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch: Any, batch_nb: int) -> STEP_OUTPUT:
        x = batch
        x = self(x)
        return x.sum()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
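

# Unlike ``BoringModel``, ``DemoModel`` defines no dataloaders of its own, so
# a minimal run passes one to ``fit`` explicitly; the feature size 32 matches
# ``self.l1``. A sketch (hypothetical helper, for illustration only):
def _example_demo_model_run() -> None:
    from lightning.pytorch import Trainer

    trainer = Trainer(fast_dev_run=True)  # run a single batch as a smoke test
    trainer.fit(DemoModel(), DataLoader(RandomDataset(32, 64)))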


class Net(nn.Module):
    """
    .. warning::  This is meant for testing/debugging and is experimental.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
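

# The ``fc1`` input size pins down the expected input shape: 9216 = 64 * 12 * 12,
# i.e. single-channel 28x28 images (MNIST-sized), since 28 -> 26 after conv1,
# 26 -> 24 after conv2, and 24 -> 12 after the 2x2 max-pool. A shape-check
# sketch (hypothetical helper, for illustration only):
def _example_net_forward() -> None:
    logits = Net()(torch.randn(2, 1, 28, 28))
    assert logits.shape == (2, 10)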