skypilot

tune_ptl_example.py 
137 lines · 4.9 KB
### Source: https://docs.ray.io/en/latest/tune/examples/mnist_ptl_mini.html
import math
import os

from filelock import FileLock
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import pytorch_lightning as pl
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import torch
from torch.nn import functional as F


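# Note (added): this example targets the legacy APIs it imports above,
# roughly pytorch_lightning 1.x (pl.metrics, validation_epoch_end, the
# gpus= and progress_bar_refresh_rate Trainer arguments), ray[tune] 1.x
# (ray.tune.integration.pytorch_lightning), and lightning-bolts for
# MNISTDataModule; newer releases moved or removed these APIs.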
class LightningMNISTClassifier(pl.LightningModule):

    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()
        self.lr = config["lr"]
        layer_1, layer_2 = config["layer_1"], config["layer_2"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
        self.layer_3 = torch.nn.Linear(layer_2, 10)
        self.accuracy = pl.metrics.Accuracy()

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x
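
    # Note (added): forward() returns log-probabilities, which is why the
    # training and validation steps below pair it with F.nll_loss; the
    # combination is equivalent to cross-entropy over the 10 classes.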

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)
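        # Note (added): the "ptl/..." names logged here are the keys that
        # the metrics dict in train_mnist_tune below maps to Tune's
        # "loss" and "acc".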


def train_mnist_tune(config, num_epochs=10, num_gpus=0):
    data_dir = os.path.abspath("./data")
    model = LightningMNISTClassifier(config, data_dir)
    with FileLock(os.path.expanduser("~/.data.lock")):
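        # Note (added): the lock serializes the MNIST download across
        # concurrent trials on the same node, so they do not race on
        # the shared ./data directory.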
        dm = MNISTDataModule(data_dir=data_dir,
                             num_workers=1,
                             batch_size=config["batch_size"])
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        progress_bar_refresh_rate=0,
        callbacks=[TuneReportCallback(metrics, on="validation_end")])
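    # Note (added): the callback above reports the mapped metrics to Tune
    # at the end of every validation epoch, so searchers and schedulers
    # can act on intermediate results.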
    trainer.fit(model, dm)


def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }
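    # Note (added): each trial samples one value per key; tune.choice
    # picks uniformly from the list and tune.loguniform samples on a log
    # scale, which suits learning rates spanning orders of magnitude.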

    trainable = tune.with_parameters(train_mnist_tune,
                                     num_epochs=num_epochs,
                                     num_gpus=gpus_per_trial)
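    # Note (added): tune.with_parameters binds the constant arguments
    # (num_epochs, num_gpus) to the trainable so only `config` varies
    # across trials.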
    analysis = tune.run(trainable,
                        resources_per_trial={
                            "cpu": 1,
                            "gpu": gpus_per_trial
                        },
                        metric="loss",
                        mode="min",
                        config=config,
                        num_samples=num_samples,
                        name="tune_mnist")

    print("Best hyperparameters found were: ", analysis.best_config)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--smoke-test",
                        action="store_true",
                        help="Finish quickly for testing")
    parser.add_argument("--server-address",
                        type=str,
                        default="auto",
                        required=False,
                        help="The address of server to connect to if using "
                        "Ray Client.")
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
    else:
        if args.server_address:
            import ray
            ray.init(args.server_address)

            print('cluster_resources:', ray.cluster_resources())
            print('available_resources:', ray.available_resources())
            print('live nodes:', ray.state.node_ids())
            resources = ray.cluster_resources()
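            # Note (added): Ray advertises GPU models as custom resources
            # named "accelerator_type:<model>"; the assert below checks
            # that the cluster actually exposes V100s before launching.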
        assert resources["accelerator_type:V100"] > 1, resources

        tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)

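# Example invocations (illustrative):
#   python tune_ptl_example.py --smoke-test    # 1 trial, 1 epoch, CPU only
#   python tune_ptl_example.py                 # 10 trials on the Ray cluster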