pytorch-lightning
123 lines · 4.1 KB
1import base642from dataclasses import dataclass3from io import BytesIO4from os import path5from typing import Dict, Optional6
7import numpy as np8import torch9import torchvision10import torchvision.transforms as T11from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo12from lightning.pytorch.cli import LightningCLI13from lightning.pytorch.serve import ServableModule, ServableModuleValidator14from lightning.pytorch.utilities.model_helpers import get_torchvision_model15from PIL import Image as PILImage16
17DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")18
19
class LitModule(LightningModule):
    """Image classifier: a torchvision backbone with its final layer resized to 10 classes."""

    def __init__(self, name: str = "resnet18"):
        super().__init__()
        # Load pretrained weights, then swap the classification head for a 10-class one.
        self.model = get_torchvision_model(name, weights="DEFAULT")
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10)
        self.criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        """Compute, log, and return the cross-entropy loss for one training batch."""
        images, targets = batch
        logits = self.model(images)
        loss = self.criterion(logits, targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one validation batch."""
        images, targets = batch
        logits = self.model(images)
        loss = self.criterion(logits, targets)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        # Plain SGD with momentum is sufficient for this short demo run.
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
43
class CIFAR10DataModule(LightningDataModule):
    """CIFAR-10 datamodule that upsizes images to the 224x224 input expected by torchvision models."""

    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled training DataLoader (downloads CIFAR-10 on first use)."""
        trainset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=True, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        """Return the validation DataLoader (downloads CIFAR-10 on first use).

        Fix: the validation set is no longer shuffled — shuffling evaluation data
        has no benefit and makes per-batch metrics non-reproducible across runs.
        """
        valset = torchvision.datasets.CIFAR10(root=DATASETS_PATH, train=False, download=True, transform=self.transform)
        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=False, num_workers=0)
55
@dataclass(unsafe_hash=True)
class Image:
    """Deserializer for a base64-encoded image request field.

    Only ``height``/``width`` are currently applied (resize before tensor
    conversion); ``extension``, ``mode`` and ``channel_first`` are declared
    but not used in :meth:`deserialize` — NOTE(review): confirm whether they
    are consumed elsewhere before relying on them.
    """

    height: Optional[int] = None
    width: Optional[int] = None
    extension: str = "JPEG"
    mode: str = "RGB"
    channel_first: bool = False

    def deserialize(self, data: str) -> torch.Tensor:
        """Decode a base64 string into a float tensor with a leading batch dim.

        Fix: pad with exactly the missing "=" characters so the string length
        is a multiple of 4, instead of always appending "===" — which produces
        an invalid-length base64 string and only works because the decoder is
        lenient about excess padding.
        """
        padded = (data + "=" * (-len(data) % 4)).encode("UTF-8")
        raw = base64.b64decode(padded)
        img = PILImage.open(BytesIO(raw), mode="r")
        if self.height and self.width:
            # PIL expects (width, height) ordering.
            img = img.resize((self.width, self.height))
        arr = np.array(img)
        # ToTensor yields (C, H, W); unsqueeze adds the batch dimension.
        return T.ToTensor()(arr).unsqueeze(0)
74
75class Top1:76def serialize(self, tensor: torch.Tensor) -> int:77return torch.nn.functional.softmax(tensor).argmax().item()78
79
class ProductionReadyModel(LitModule, ServableModule):
    """LitModule extended with the ServableModule hooks checked by ServableModuleValidator."""

    def configure_payload(self):
        """Build a sample request payload from the first training image."""
        # Grab a single sample out of the train dataloader's dataset...
        sample, _ = self.trainer.train_dataloader.dataset[0]

        # ...then round-trip it through PIL/JPEG and base64-encode the bytes.
        buffer = BytesIO()
        T.ToPILImage()(sample).save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue()).decode("UTF-8")

        return {"body": {"x": encoded}}

    def configure_serialization(self):
        """Map request/response field names to their (de)serializer callables."""
        deserializers = {"x": Image(224, 224).deserialize}
        serializers = {"output": Top1().serialize}
        return deserializers, serializers

    def serve_step(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run inference on an already-deserialized input tensor."""
        return {"output": self.model(x)}

    def configure_response(self):
        """Response the validator expects for its sanity-check request."""
        return {"output": 7}
102
103def cli_main():104cli = LightningCLI(105ProductionReadyModel,106CIFAR10DataModule,107seed_everything_default=42,108save_config_kwargs={"overwrite": True},109run=False,110trainer_defaults={111"accelerator": "cpu",112"callbacks": [ServableModuleValidator()],113"max_epochs": 1,114"limit_train_batches": 5,115"limit_val_batches": 5,116},117)118cli.trainer.fit(cli.model, cli.datamodule)119
120
121if __name__ == "__main__":122cli_lightning_logo()123cli_main()124