pytorch-lightning

Форк
0
123 строки · 4.1 Кб
1
import base64
2
from dataclasses import dataclass
3
from io import BytesIO
4
from os import path
5
from typing import Dict, Optional
6

7
import numpy as np
8
import torch
9
import torchvision
10
import torchvision.transforms as T
11
from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo
12
from lightning.pytorch.cli import LightningCLI
13
from lightning.pytorch.serve import ServableModule, ServableModuleValidator
14
from lightning.pytorch.utilities.model_helpers import get_torchvision_model
15
from PIL import Image as PILImage
16

17
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
18

19

20
class LitModule(LightningModule):
21
    def __init__(self, name: str = "resnet18"):
22
        super().__init__()
23
        self.model = get_torchvision_model(name, weights="DEFAULT")
24
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10)
25
        self.criterion = torch.nn.CrossEntropyLoss()
26

27
    def training_step(self, batch, batch_idx):
28
        inputs, labels = batch
29
        outputs = self.model(inputs)
30
        loss = self.criterion(outputs, labels)
31
        self.log("train_loss", loss)
32
        return loss
33

34
    def validation_step(self, batch, batch_idx):
35
        inputs, labels = batch
36
        outputs = self.model(inputs)
37
        loss = self.criterion(outputs, labels)
38
        self.log("val_loss", loss)
39

40
    def configure_optimizers(self):
41
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
42

43

44
class CIFAR10DataModule(LightningDataModule):
45
    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
46

47
    def train_dataloader(self, *args, **kwargs):
48
        trainset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=True, download=True, transform=self.transform)
49
        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)
50

51
    def val_dataloader(self, *args, **kwargs):
52
        valset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=False, download=True, transform=self.transform)
53
        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=True, num_workers=0)
54

55

56
@dataclass(unsafe_hash=True)
57
class Image:
58
    height: Optional[int] = None
59
    width: Optional[int] = None
60
    extension: str = "JPEG"
61
    mode: str = "RGB"
62
    channel_first: bool = False
63

64
    def deserialize(self, data: str) -> torch.Tensor:
65
        encoded_with_padding = (data + "===").encode("UTF-8")
66
        img = base64.b64decode(encoded_with_padding)
67
        buffer = BytesIO(img)
68
        img = PILImage.open(buffer, mode="r")
69
        if self.height and self.width:
70
            img = img.resize((self.width, self.height))
71
        arr = np.array(img)
72
        return T.ToTensor()(arr).unsqueeze(0)
73

74

75
class Top1:
76
    def serialize(self, tensor: torch.Tensor) -> int:
77
        return torch.nn.functional.softmax(tensor).argmax().item()
78

79

80
class ProductionReadyModel(LitModule, ServableModule):
81
    def configure_payload(self):
82
        # 1: Access the train dataloader and load a single sample.
83
        image, _ = self.trainer.train_dataloader.dataset[0]
84

85
        # 2: Convert the image into a PIL Image to bytes and encode it with base64
86
        pil_image = T.ToPILImage()(image)
87
        buffered = BytesIO()
88
        pil_image.save(buffered, format="JPEG")
89
        img_str = base64.b64encode(buffered.getvalue()).decode("UTF-8")
90

91
        return {"body": {"x": img_str}}
92

93
    def configure_serialization(self):
94
        return {"x": Image(224, 224).deserialize}, {"output": Top1().serialize}
95

96
    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
97
        return {"output": self.model(x)}
98

99
    def configure_response(self):
100
        return {"output": 7}
101

102

103
def cli_main():
104
    cli = LightningCLI(
105
        ProductionReadyModel,
106
        CIFAR10DataModule,
107
        seed_everything_default=42,
108
        save_config_kwargs={"overwrite": True},
109
        run=False,
110
        trainer_defaults={
111
            "accelerator": "cpu",
112
            "callbacks": [ServableModuleValidator()],
113
            "max_epochs": 1,
114
            "limit_train_batches": 5,
115
            "limit_val_batches": 5,
116
        },
117
    )
118
    cli.trainer.fit(cli.model, cli.datamodule)
119

120

121
if __name__ == "__main__":
122
    cli_lightning_logo()
123
    cli_main()
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.