pytorch-lightning

computer_vision_fine_tuning.py 
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a
15
pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats
16
and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see
17
below) is trained for 15 epochs.
18

19
The training consists of three stages.
20

21
From epoch 0 to 4, the feature extractor (the pre-trained network) is frozen except
22
maybe for the BatchNorm layers (depending on whether `train_bn = True`). The BatchNorm
23
layers (if `train_bn = True`) and the parameters of the classifier are trained as a
24
single parameters group with lr = 1e-2.
25

26
From epoch 5 to 9, the last two layer groups of the pre-trained network are unfrozen
27
and added to the optimizer as a new parameter group with lr = 1e-4 (while lr = 1e-3
28
for the first parameter group in the optimizer).
29

30
Eventually, from epoch 10, all the remaining layer groups of the pre-trained network
31
are unfrozen and added to the optimizer as a third parameter group. From epoch 10,
32
the parameters of the pre-trained network are trained with lr = 1e-5 while those of
33
the classifier is trained with lr = 1e-4.
34

35
Note:
36
    See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
37

38
To run:
39
    python computer_vision_fine_tuning.py fit
40

41
"""

import logging
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo
from lightning.pytorch.callbacks.finetuning import BaseFinetuning
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.utilities import rank_zero_info
from lightning.pytorch.utilities.model_helpers import get_torchvision_model
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive

log = logging.getLogger(__name__)
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"

#  --- Finetuning Callback ---


class MilestonesFinetuning(BaseFinetuning):
    def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
        super().__init__()
        self.milestones = milestones
        self.train_bn = train_bn

    def freeze_before_training(self, pl_module: LightningModule):
        self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)

    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer):
        if epoch == self.milestones[0]:
            # unfreeze the last five children of the feature extractor
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
            )

        elif epoch == self.milestones[1]:
            # unfreeze the remaining layers
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
            )
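
# Note on the learning rate of newly added parameter groups: when no explicit `lr` is
# passed, `BaseFinetuning.unfreeze_and_add_param_group` starts the new group at the
# current lr of the optimizer's first param group divided by its `initial_denom_lr`
# argument (10 by default), e.g. if param_groups[0]["lr"] == 1e-3 at the milestone, the
# new group starts at 1e-4. Combined with the `MultiStepLR` schedule configured below,
# this produces the staged learning rates described in the module docstring.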


class CatDogImageDataModule(LightningDataModule):
    def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8):
        """CatDogImageDataModule.

        Args:
            dl_path: root directory where to download the data
            num_workers: number of CPU workers
            batch_size: number of samples in a batch

        """
        super().__init__()

        self._dl_path = dl_path
        self._num_workers = num_workers
        self._batch_size = batch_size

    def prepare_data(self):
        """Download images and prepare the image datasets."""
        download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True)

    @property
    def data_path(self):
        return Path(self._dl_path).joinpath("cats_and_dogs_filtered")

    @property
    def normalize_transform(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    @property
    def train_transform(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            self.normalize_transform,
        ])

    @property
    def valid_transform(self):
        return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform])

    def create_dataset(self, root, transform):
        return ImageFolder(root=root, transform=transform)

    def __dataloader(self, train: bool):
        """Train/validation loaders."""
        if train:
            dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform)
        else:
            dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform)
        return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train)

    def train_dataloader(self):
        log.info("Training data loaded.")
        return self.__dataloader(train=True)

    def val_dataloader(self):
        log.info("Validation data loaded.")
        return self.__dataloader(train=False)
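
# Standalone usage sketch (hypothetical; not part of the CLI flow below):
#
#   dm = CatDogImageDataModule(dl_path="data", batch_size=8)
#   dm.prepare_data()  # downloads and extracts ~60MB into ./data
#   images, labels = next(iter(dm.train_dataloader()))
#   # images: float tensor of shape [8, 3, 224, 224]; labels: class indices assigned
#   # alphabetically by ImageFolder ("cats" -> 0, "dogs" -> 1)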


#  --- PyTorch Lightning module ---


class TransferLearningModel(LightningModule):
    def __init__(
        self,
        backbone: str = "resnet50",
        train_bn: bool = False,
        milestones: tuple = (2, 4),
        batch_size: int = 32,
        lr: float = 1e-3,
        lr_scheduler_gamma: float = 1e-1,
        num_workers: int = 6,
        **kwargs,
    ) -> None:
        """TransferLearningModel.

        Args:
            backbone: Name (as in ``torchvision.models``) of the feature extractor
            train_bn: Whether the BatchNorm layers should be trainable
            milestones: Tuple of two milestone epochs
            lr: Initial learning rate
            lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone

        """
        super().__init__()
        self.backbone = backbone
        self.train_bn = train_bn
        self.milestones = milestones
        self.batch_size = batch_size
        self.lr = lr
        self.lr_scheduler_gamma = lr_scheduler_gamma
        self.num_workers = num_workers

        self.__build_model()

        self.train_acc = Accuracy(task="binary")
        self.valid_acc = Accuracy(task="binary")
        self.save_hyperparameters()

    def __build_model(self):
        """Define model layers & loss."""
        # 1. Load pre-trained network:
        backbone = get_torchvision_model(self.backbone, weights="DEFAULT")

        _layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*_layers)

        # 2. Classifier:
        _fc_layers = [nn.Linear(2048, 256), nn.ReLU(), nn.Linear(256, 32), nn.Linear(32, 1)]
        self.fc = nn.Sequential(*_fc_layers)

        # 3. Loss:
        self.loss_func = F.binary_cross_entropy_with_logits

    def forward(self, x):
        """Forward pass.

        Returns logits.

        """
        # 1. Feature extraction:
        x = self.feature_extractor(x)
        x = x.squeeze(-1).squeeze(-1)

        # 2. Classifier (returns logits):
        return self.fc(x)
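
    # Shape sketch for the default ResNet50 backbone (illustrative): an input of shape
    # [B, 3, 224, 224] comes out of `feature_extractor` (which ends with the average
    # pool) as [B, 2048, 1, 1]; the two squeezes yield [B, 2048], and `fc` maps that to
    # one logit per sample, i.e. [B, 1], matching `y.view((-1, 1))` in the steps below.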

    def loss(self, logits, labels):
        return self.loss_func(input=logits, target=labels)

    def training_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_scores = torch.sigmoid(y_logits)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss:
        train_loss = self.loss(y_logits, y_true)

        # 3. Compute accuracy:
        self.log("train_acc", self.train_acc(y_scores, y_true.int()), prog_bar=True)

        return train_loss
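
    # Why logits here: `binary_cross_entropy_with_logits` fuses the sigmoid with the BCE
    # term, which is more numerically stable than applying `sigmoid` and then
    # `binary_cross_entropy`; `torch.sigmoid` is applied separately only to produce the
    # probabilities handed to the `Accuracy(task="binary")` metric.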

    def validation_step(self, batch, batch_idx):
        # 1. Forward pass:
        x, y = batch
        y_logits = self.forward(x)
        y_scores = torch.sigmoid(y_logits)
        y_true = y.view((-1, 1)).type_as(x)

        # 2. Compute loss:
        self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True)

        # 3. Compute accuracy:
        self.log("val_acc", self.valid_acc(y_scores, y_true.int()), prog_bar=True)

    def configure_optimizers(self):
        parameters = list(self.parameters())
        trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
        rank_zero_info(
            f"The model will start training with only {len(trainable_parameters)} "
            f"trainable parameters out of {len(parameters)}."
        )
        optimizer = optim.Adam(trainable_parameters, lr=self.lr)
        scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
        return [optimizer], [scheduler]
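
    # Scheduler note: at each milestone epoch, `MultiStepLR` multiplies the lr of every
    # param group currently in the optimizer by `gamma`, including the groups that
    # `MilestonesFinetuning` adds after training has started.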


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_lightning_class_args(MilestonesFinetuning, "finetuning")
        parser.link_arguments("data.batch_size", "model.batch_size")
        parser.link_arguments("finetuning.milestones", "model.milestones")
        parser.link_arguments("finetuning.train_bn", "model.train_bn")
        parser.set_defaults({
            "trainer.max_epochs": 15,
            "trainer.enable_model_summary": False,
            "trainer.num_sanity_val_steps": 0,
        })
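
# The same settings can be supplied via a config file, e.g.
# `python computer_vision_fine_tuning.py fit --config config.yaml`; a minimal sketch
# (assuming the argument names exposed above):
#
#   finetuning:
#     milestones: [5, 10]
#     train_bn: false
#   model:
#     backbone: resnet50
#     lr: 0.001
#   data:
#     dl_path: data
#     batch_size: 8
#   trainer:
#     max_epochs: 15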


def cli_main():
    MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)


if __name__ == "__main__":
    cli_lightning_logo()
    cli_main()