# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a
pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the 'cats
and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, see
below) is trained for 15 epochs.

The training consists of three stages.

From epoch 0 to 4, the feature extractor (the pre-trained network) is frozen except
maybe for the BatchNorm layers (depending on whether `train_bn = True`). The BatchNorm
layers (if `train_bn = True`) and the parameters of the classifier are trained as a
single parameters group with lr = 1e-2.

From epoch 5 to 9, the last two layer groups of the pre-trained network are unfrozen
and added to the optimizer as a new parameter group with lr = 1e-4 (while lr = 1e-3
for the first parameter group in the optimizer).

Eventually, from epoch 10, all the remaining layer groups of the pre-trained network
are unfrozen and added to the optimizer as a third parameter group. From epoch 10,
the parameters of the pre-trained network are trained with lr = 1e-5 while those of
the classifier are trained with lr = 1e-4.

Note:
    See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

To run:
    python computer_vision_fine_tuning.py fit

"""
42
43import logging
44from pathlib import Path
45from typing import Union
46
47import torch
48import torch.nn.functional as F
49from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo
50from lightning.pytorch.callbacks.finetuning import BaseFinetuning
51from lightning.pytorch.cli import LightningCLI
52from lightning.pytorch.utilities import rank_zero_info
53from lightning.pytorch.utilities.model_helpers import get_torchvision_model
54from torch import nn, optim
55from torch.optim.lr_scheduler import MultiStepLR
56from torch.optim.optimizer import Optimizer
57from torch.utils.data import DataLoader
58from torchmetrics import Accuracy
59from torchvision import transforms
60from torchvision.datasets import ImageFolder
61from torchvision.datasets.utils import download_and_extract_archive
62
63log = logging.getLogger(__name__)
64DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
65
# --- Finetuning Callback ---


69class MilestonesFinetuning(BaseFinetuning):
70def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
71super().__init__()
72self.milestones = milestones
73self.train_bn = train_bn
74
75def freeze_before_training(self, pl_module: LightningModule):
76self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)
77
78def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer):
79if epoch == self.milestones[0]:
80# unfreeze 5 last layers
81self.unfreeze_and_add_param_group(
82modules=pl_module.feature_extractor[-5:], optimizer=optimizer, train_bn=self.train_bn
83)
84
85elif epoch == self.milestones[1]:
86# unfreeze remaining layers
87self.unfreeze_and_add_param_group(
88modules=pl_module.feature_extractor[:-5], optimizer=optimizer, train_bn=self.train_bn
89)
90
91
92class CatDogImageDataModule(LightningDataModule):
93def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8):
94"""CatDogImageDataModule.
95
96Args:
97dl_path: root directory where to download the data
98num_workers: number of CPU workers
99batch_size: number of sample in a batch
100
101"""
102super().__init__()
103
104self._dl_path = dl_path
105self._num_workers = num_workers
106self._batch_size = batch_size
107
108def prepare_data(self):
109"""Download images and prepare images datasets."""
110download_and_extract_archive(url=DATA_URL, download_root=self._dl_path, remove_finished=True)
111
112@property
113def data_path(self):
114return Path(self._dl_path).joinpath("cats_and_dogs_filtered")
115
116@property
117def normalize_transform(self):
118return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
119
120@property
121def train_transform(self):
122return transforms.Compose([
123transforms.Resize((224, 224)),
124transforms.RandomHorizontalFlip(),
125transforms.ToTensor(),
126self.normalize_transform,
127])
128
129@property
130def valid_transform(self):
131return transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), self.normalize_transform])
132
133def create_dataset(self, root, transform):
134return ImageFolder(root=root, transform=transform)
135
136def __dataloader(self, train: bool):
137"""Train/validation loaders."""
138if train:
139dataset = self.create_dataset(self.data_path.joinpath("train"), self.train_transform)
140else:
141dataset = self.create_dataset(self.data_path.joinpath("validation"), self.valid_transform)
142return DataLoader(dataset=dataset, batch_size=self._batch_size, num_workers=self._num_workers, shuffle=train)
143
144def train_dataloader(self):
145log.info("Training data loaded.")
146return self.__dataloader(train=True)
147
148def val_dataloader(self):
149log.info("Validation data loaded.")
150return self.__dataloader(train=False)
151
152
# --- PyTorch Lightning module ---


156class TransferLearningModel(LightningModule):
157def __init__(
158self,
159backbone: str = "resnet50",
160train_bn: bool = False,
161milestones: tuple = (2, 4),
162batch_size: int = 32,
163lr: float = 1e-3,
164lr_scheduler_gamma: float = 1e-1,
165num_workers: int = 6,
166**kwargs,
167) -> None:
168"""TransferLearningModel.
169
170Args:
171backbone: Name (as in ``torchvision.models``) of the feature extractor
172train_bn: Whether the BatchNorm layers should be trainable
173milestones: List of two epochs milestones
174lr: Initial learning rate
175lr_scheduler_gamma: Factor by which the learning rate is reduced at each milestone
176
177"""
178super().__init__()
179self.backbone = backbone
180self.train_bn = train_bn
181self.milestones = milestones
182self.batch_size = batch_size
183self.lr = lr
184self.lr_scheduler_gamma = lr_scheduler_gamma
185self.num_workers = num_workers
186
187self.__build_model()
188
189self.train_acc = Accuracy(task="binary")
190self.valid_acc = Accuracy(task="binary")
191self.save_hyperparameters()
192
193def __build_model(self):
194"""Define model layers & loss."""
195# 1. Load pre-trained network:
196backbone = get_torchvision_model(self.backbone, weights="DEFAULT")
197
198_layers = list(backbone.children())[:-1]
199self.feature_extractor = nn.Sequential(*_layers)
200
201# 2. Classifier:
202_fc_layers = [nn.Linear(2048, 256), nn.ReLU(), nn.Linear(256, 32), nn.Linear(32, 1)]
203self.fc = nn.Sequential(*_fc_layers)
204
205# 3. Loss:
206self.loss_func = F.binary_cross_entropy_with_logits
207
208def forward(self, x):
209"""Forward pass.
210
211Returns logits.
212
213"""
214# 1. Feature extraction:
215x = self.feature_extractor(x)
216x = x.squeeze(-1).squeeze(-1)
217
218# 2. Classifier (returns logits):
219return self.fc(x)
220
221def loss(self, logits, labels):
222return self.loss_func(input=logits, target=labels)
223
224def training_step(self, batch, batch_idx):
225# 1. Forward pass:
226x, y = batch
227y_logits = self.forward(x)
228y_scores = torch.sigmoid(y_logits)
229y_true = y.view((-1, 1)).type_as(x)
230
231# 2. Compute loss
232train_loss = self.loss(y_logits, y_true)
233
234# 3. Compute accuracy:
235self.log("train_acc", self.train_acc(y_scores, y_true.int()), prog_bar=True)
236
237return train_loss
238
239def validation_step(self, batch, batch_idx):
240# 1. Forward pass:
241x, y = batch
242y_logits = self.forward(x)
243y_scores = torch.sigmoid(y_logits)
244y_true = y.view((-1, 1)).type_as(x)
245
246# 2. Compute loss
247self.log("val_loss", self.loss(y_logits, y_true), prog_bar=True)
248
249# 3. Compute accuracy:
250self.log("val_acc", self.valid_acc(y_scores, y_true.int()), prog_bar=True)
251
252def configure_optimizers(self):
253parameters = list(self.parameters())
254trainable_parameters = list(filter(lambda p: p.requires_grad, parameters))
255rank_zero_info(
256f"The model will start training with only {len(trainable_parameters)} "
257f"trainable parameters out of {len(parameters)}."
258)
259optimizer = optim.Adam(trainable_parameters, lr=self.lr)
260scheduler = MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_scheduler_gamma)
261return [optimizer], [scheduler]
262
263
264class MyLightningCLI(LightningCLI):
265def add_arguments_to_parser(self, parser):
266parser.add_lightning_class_args(MilestonesFinetuning, "finetuning")
267parser.link_arguments("data.batch_size", "model.batch_size")
268parser.link_arguments("finetuning.milestones", "model.milestones")
269parser.link_arguments("finetuning.train_bn", "model.train_bn")
270parser.set_defaults({
271"trainer.max_epochs": 15,
272"trainer.enable_model_summary": False,
273"trainer.num_sanity_val_steps": 0,
274})
275
276
277def cli_main():
278MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)
279
280
281if __name__ == "__main__":
282cli_lightning_logo()
283cli_main()
284