pytorch-lightning
1"""Demo of a simple transformer language model.
2
3Code is adapted from the PyTorch examples at
4https://github.com/pytorch/examples/blob/main/word_language_model
5
6"""
7
8import math9import os10from pathlib import Path11from typing import Dict, List, Optional, Tuple12
13import torch14import torch.nn as nn15import torch.nn.functional as F16from lightning_utilities.core.imports import RequirementCache17from torch import Tensor18from torch.nn.modules import MultiheadAttention19from torch.utils.data import DataLoader, Dataset20
21from lightning.pytorch import LightningModule22
23_REQUESTS_AVAILABLE = RequirementCache("requests")24


if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAttention, "reset_parameters"):
    # See https://github.com/pytorch/pytorch/issues/107909
    MultiheadAttention.reset_parameters = MultiheadAttention._reset_parameters


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int = 33278,  # default for WikiText2
        ninp: int = 200,
        nhead: int = 2,
        nhid: int = 200,
        nlayers: int = 2,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.embedding = nn.Embedding(vocab_size, ninp)
        self.transformer = nn.Transformer(
            d_model=ninp,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=nhid,
            dropout=dropout,
            batch_first=True,
        )
        self.decoder = nn.Linear(ninp, vocab_size)

        self.ninp = ninp
        self.vocab_size = vocab_size
        self.src_mask = None

    def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        _, t = inputs.shape

        # we assume target is already shifted w.r.t. inputs
        if mask is None:
            # causal mask: position i may only attend to positions <= i
            mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1
            mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))

        src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp))
        target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp))
        output = self.transformer(src, target, tgt_mask=mask)
        output = self.decoder(output)
        output = F.log_softmax(output, dim=-1)
        # flatten to (batch * seq_len, vocab_size) so the output can feed nll_loss directly
        output = output.view(-1, self.vocab_size)
        return output
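

# Minimal usage sketch (illustrative helper, not called anywhere in the example): one
# forward pass with dummy token ids and made-up hyperparameters, just to show that the
# model flattens its log-probabilities to (batch * seq_len, vocab_size).
def _demo_transformer_forward() -> None:
    model = Transformer(vocab_size=100, ninp=32, nhead=2, nhid=64, nlayers=1)
    inputs = torch.randint(0, 100, (2, 35))  # (batch, seq_len)
    target = torch.randint(0, 100, (2, 35))  # assumed to be already shifted w.r.t. inputs
    output = model(inputs, target)
    assert output.shape == (2 * 35, 100)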


class PositionalEncoding(nn.Module):
    def __init__(self, dim: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim
        self.max_len = max_len
        self.pe: Optional[Tensor] = None

    def forward(self, x: Tensor) -> Tensor:
        if self.pe is None:
            # 1) can't use buffer, see https://github.com/pytorch/pytorch/issues/68407
            # 2) can't use parameter because pe gets sliced and DDP requires all params to participate in forward
            # 3) can't make it a `requires_grad=False` parameter because FSDP in PyTorch < 2.1 needs all params to
            #    require grad
            self.pe = self._init_pos_encoding(device=x.device)

        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)

    def _init_pos_encoding(self, device: torch.device) -> Tensor:
        # standard sinusoidal encoding: sine on even dimensions, cosine on odd dimensions
        pe = torch.zeros(self.max_len, self.dim, device=device)
        position = torch.arange(0, self.max_len, dtype=torch.float, device=device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        return pe


class WikiText2(Dataset):
    """Mini version of WikiText2."""

    def __init__(self, data_dir: Path = Path("./data"), block_size: int = 35, download: bool = True) -> None:
        super().__init__()
        self.path = data_dir / "wikitext-2.txt"
        if download:
            self.download(self.path)
        self.data, self.dictionary = tokenize(self.path)
        self.block_size = block_size

    @property
    def vocab_size(self) -> int:
        return len(self.dictionary)

    def __len__(self) -> int:
        return len(self.data) // self.block_size - 1

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        # target is the same block of token ids shifted by one position
        start = index * self.block_size
        end = start + self.block_size
        inputs = self.data[start:end]
        target = self.data[(start + 1) : (end + 1)]
        return inputs, target

    @staticmethod
    def download(destination: Path) -> None:
        if not _REQUESTS_AVAILABLE:
            raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE))

        import requests

        os.makedirs(destination.parent, exist_ok=True)
        url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt"
        if os.path.exists(destination):
            return
        with open(destination, "w") as f:
            f.write(requests.get(url).text)


class Dictionary:
    def __init__(self) -> None:
        self.word2idx: Dict[str, int] = {}
        self.idx2word: List[str] = []

    def add_word(self, word: str) -> int:
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self) -> int:
        return len(self.idx2word)


def tokenize(path: Path) -> Tuple[Tensor, Dictionary]:
    dictionary = Dictionary()

    assert os.path.exists(path)
    # Add words to the dictionary
    with open(path, encoding="utf8") as f:
        for line in f:
            words = line.split() + ["<eos>"]
            for word in words:
                dictionary.add_word(word)

    # Tokenize file content
    with open(path, encoding="utf8") as f:
        idss: List[Tensor] = []
        for line in f:
            words = line.split() + ["<eos>"]
            ids: List[int] = []
            for word in words:
                ids.append(dictionary.word2idx[word])
            idss.append(torch.tensor(ids).type(torch.int64))

    return torch.cat(idss), dictionary
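

# Minimal data-pipeline sketch (illustrative helper, not part of the training path):
# builds the dataset, which downloads the raw text on first use and therefore assumes
# network access, and reads one (inputs, target) pair back through a DataLoader.
def _demo_wikitext2_batch() -> None:
    dataset = WikiText2(data_dir=Path("./data"), block_size=35)
    loader = DataLoader(dataset, batch_size=2)
    inputs, target = next(iter(loader))
    # each row is one block of token ids; target is the block shifted by one position
    assert inputs.shape == (2, 35) and target.shape == (2, 35)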


class LightningTransformer(LightningModule):
    def __init__(self, vocab_size: int = 33278) -> None:
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)

    def forward(self, inputs: Tensor, target: Tensor) -> Tensor:
        return self.model(inputs, target)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        inputs, target = batch
        output = self(inputs, target)
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

    def prepare_data(self) -> None:
        WikiText2(download=True)

    def train_dataloader(self) -> DataLoader:
        dataset = WikiText2()
        return DataLoader(dataset)
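

# Minimal training sketch, assuming the standard Lightning Trainer API: `fast_dev_run`
# pushes a single batch through `training_step` to check the wiring rather than training
# to convergence.
if __name__ == "__main__":
    from lightning.pytorch import Trainer

    dataset = WikiText2()
    model = LightningTransformer(vocab_size=dataset.vocab_size)
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, train_dataloaders=DataLoader(dataset))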