skypilot
137 lines · 4.9 KB
### Source: https://docs.ray.io/en/latest/tune/examples/mnist_ptl_mini.html
import math
import os

from filelock import FileLock
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import pytorch_lightning as pl
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import torch
from torch.nn import functional as F
13
class LightningMNISTClassifier(pl.LightningModule):
    """Two-hidden-layer MLP classifier for MNIST.

    Hyperparameters are read from ``config``:

    - ``lr``: learning rate for Adam.
    - ``layer_1`` / ``layer_2``: widths of the two hidden layers.
    - ``batch_size``: stored for use by the surrounding tuning code.

    Logs ``ptl/train_loss`` / ``ptl/train_accuracy`` per training step and
    ``ptl/val_loss`` / ``ptl/val_accuracy`` per validation epoch, which the
    Ray Tune callback reports back to the scheduler.
    """

    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        # Fall back to the working directory when no data dir is given.
        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Flattens each (channels, width, height) image and passes it through
        two ReLU layers followed by a log-softmax output layer.
        """
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        # log-probabilities pair with F.nll_loss in the step methods below.
        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        """Use Adam with the tuned learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute NLL loss and accuracy for one batch; log both."""
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Return per-batch validation loss/accuracy for epoch aggregation."""
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        """Average the batch metrics and log them for the Tune callback."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
65
66
def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    """Train one Tune trial of the MNIST classifier.

    Args:
        config: Hyperparameter dict (``lr``, ``layer_1``, ``layer_2``,
            ``batch_size``) sampled by Ray Tune.
        num_epochs: Maximum number of training epochs.
        num_gpus: GPUs requested for this trial; may be fractional and is
            rounded up for the Lightning Trainer.

    Metrics are reported back to Tune at the end of each validation pass via
    ``TuneReportCallback``.
    """
    data_dir = os.path.abspath("./data")
    model = LightningMNISTClassifier(config, data_dir)
    # Serialize the MNIST download: concurrent trials on one node would
    # otherwise race on the same data directory.
    with FileLock(os.path.expanduser("~/.data.lock")):
        dm = MNISTDataModule(data_dir=data_dir,
                             num_workers=1,
                             batch_size=config["batch_size"])
    # Map the logged Lightning metric names to the names Tune reports.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
    trainer.fit(model, dm)
82
83
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    """Run a Ray Tune hyperparameter search over the MNIST classifier.

    Args:
        num_samples: Number of hyperparameter configurations to sample.
        num_epochs: Maximum training epochs per trial.
        gpus_per_trial: GPUs allocated to each trial (0 for CPU-only).

    Minimizes the reported validation loss and prints the best config found.
    """
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    # Bind the fixed keyword arguments so Tune only varies `config`.
    trainable = tune.with_parameters(train_mnist_tune,
                                     num_epochs=num_epochs,
                                     num_gpus=gpus_per_trial)
    analysis = tune.run(trainable,
                        resources_per_trial={
                            "cpu": 1,
                            "gpu": gpus_per_trial
                        },
                        metric="loss",
                        mode="min",
                        config=config,
                        num_samples=num_samples,
                        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)
107
108
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--smoke-test",
                        action="store_true",
                        help="Finish quickly for testing")
    parser.add_argument("--server-address",
                        type=str,
                        default="auto",
                        required=False,
                        help="The address of server to connect to if using "
                        "Ray Client.")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        # Minimal run: one sample, one epoch, CPU only.
        tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
    else:
        # NOTE(review): the default "auto" is truthy, so this branch always
        # connects to a Ray cluster unless --server-address is set to "".
        if args.server_address:
            import ray
            ray.init(args.server_address)

        print('cluster_resources:', ray.cluster_resources())
        print('available_resources:', ray.available_resources())
        print('live nodes:', ray.state.node_ids())
        resources = ray.cluster_resources()
        # Sanity check: this example expects a cluster with >1 V100 GPU
        # registered, even though trials below run with gpus_per_trial=0.
        assert resources["accelerator_type:V100"] > 1, resources

        tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)
138