TransformerEngine / test_single_gpu_mnist.py

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""MNIST example of Transformer Engine Paddle"""

import argparse
import os
import unittest

import paddle
from paddle import nn
import paddle.nn.functional as F

from paddle.vision.transforms import Normalize
from paddle.io import DataLoader
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy

import transformer_engine.paddle as te
from transformer_engine.paddle.fp8 import is_fp8_available
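

# This example trains a small CNN on MNIST with Paddle. With --use-te the two
# hidden fully connected layers are replaced by Transformer Engine's te.Linear,
# and with --use-fp8 those layers run in FP8 inside te.fp8_autocast().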
class Net(nn.Layer):
    """Simple network used to train on MNIST"""

    def __init__(self, use_te=False):
        super().__init__()
        self.conv1 = nn.Conv2D(1, 32, 3, 1)
        self.conv2 = nn.Conv2D(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        if use_te:
            self.fc1 = te.Linear(9216, 128)
            self.fc2 = te.Linear(128, 16)
        else:
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 16)
        # fc2 produces 16 features rather than the 10 class logits directly,
        # presumably to keep the te.Linear dimensions multiples of 16 for FP8;
        # fc3 then projects to the 10 MNIST classes.
        self.fc3 = nn.Linear(16, 10)

    def forward(self, x):
        """FWD"""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = paddle.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
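

# Training loop for one epoch. Note how te.fp8_autocast() wraps only the model
# forward pass, so just the TE layers execute in FP8 while the surrounding ops
# (loss, metrics) run in BF16 under paddle.amp.auto_cast.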
def train(args, model, train_loader, optimizer, epoch, use_fp8):
    """Training function."""
    model.train()
    losses = []
    for batch_id, (data, labels) in enumerate(train_loader):
        with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):    # pylint: disable=not-context-manager
            with te.fp8_autocast(enabled=use_fp8):
                outputs = model(data)
            loss = F.cross_entropy(outputs, labels)
            losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        if batch_id % args.log_interval == 0:
            print(f"Train Epoch: {epoch} "
                  f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_id / len(train_loader):.0f}%)]\t"
                  f"Loss: {loss.item():.6f}")
            if args.dry_run:
                return loss.item()
    avg_loss = sum(losses) / len(losses)
    print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}")
    return avg_loss
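

# Evaluation uses the same BF16/FP8 autocast nesting as training and tracks
# top-1 accuracy with paddle.metric.Accuracy.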
def evaluate(model, test_loader, epoch, use_fp8):
    """Testing function."""
    model.eval()
    metric = Accuracy()
    metric.reset()

    with paddle.no_grad():
        for data, labels in test_loader:
            with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):    # pylint: disable=not-context-manager
                with te.fp8_autocast(enabled=use_fp8):
                    outputs = model(data)
                acc = metric.compute(outputs, labels)
            metric.update(acc)
    print(f"Epoch[{epoch}] - accuracy: {metric.accumulate():.6f}")
    return metric.accumulate()
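

# Calibration pass: running the model with fp8_autocast(enabled=False,
# calibrating=True) collects the FP8 scaling statistics, so a model trained in
# BF16 can later be evaluated or served in FP8.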
def calibrate(model, test_loader):
    """Calibration function."""
    model.eval()

    with paddle.no_grad():
        for data, _ in test_loader:
            with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):    # pylint: disable=not-context-manager
                with te.fp8_autocast(enabled=False, calibrating=True):
                    _ = model(data)
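

# Command-line flags. The TE-specific ones are --use-te (use Transformer Engine
# layers), --use-fp8 (FP8 training and inference) and --use-fp8-infer (FP8 for
# inference only, with a calibration pass if training was not done in FP8).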
def mnist_parser(args):
    """Parse training settings"""
    parser = argparse.ArgumentParser(description="Paddle MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="learning rate (default: 0.001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="save the current model",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--use-fp8",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference and training without recalibration. " \
                             "It also enables Transformer Engine implicitly.")
    parser.add_argument("--use-fp8-infer",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference only. If not using FP8 for training, "
                        "calibration is performed for FP8 inference.")
    parser.add_argument("--use-te",
                        action="store_true",
                        default=False,
                        help="Use Transformer Engine")
    args = parser.parse_args(args)
    return args
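

# End-to-end driver: build the datasets and model, train for the requested
# number of epochs, optionally calibrate for FP8 inference, and, when a
# checkpoint is saved, reload it and evaluate once more.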
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)

    paddle.seed(args.seed)

    # Load MNIST dataset
    transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
    train_dataset = MNIST(mode='train', transform=transform)
    val_dataset = MNIST(mode='test', transform=transform)

    # Define data loaders
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size)

    # Define model and optimizer
    model = Net(use_te=args.use_te)
    optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())

    # Cast model to BF16
    model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')

    for epoch in range(1, args.epochs + 1):
        loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
        acc = evaluate(model, val_loader, epoch, args.use_fp8)

    if args.use_fp8_infer and not args.use_fp8:
        calibrate(model, val_loader)

    if args.save_model or args.use_fp8_infer:
        paddle.save(model.state_dict(), "mnist_cnn.pdparams")
        print('Eval with reloaded checkpoint: fp8=' + str(args.use_fp8))
        weights = paddle.load("mnist_cnn.pdparams")
        model.set_state_dict(weights)
        acc = evaluate(model, val_loader, 0, args.use_fp8)

    return loss, acc
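

# Single-GPU unittests: run the example for a few epochs and check that the
# final training loss and test accuracy reach the expected thresholds. FP8
# cases are skipped on GPUs without FP8 support.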
class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

    gpu_has_fp8, reason = is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Set up common arguments for the MNIST tests"""
        cls.args = mnist_parser(["--epochs", "5"])

    @staticmethod
    def verify(actual):
        """Check if loss and accuracy match the targets"""
        desired_training_loss = 0.1
        desired_test_accuracy = 0.98
        assert actual[0] < desired_training_loss
        assert actual[1] > desired_test_accuracy

    @unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0),
                     "BF16 MNIST example requires Ampere+ GPU")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        self.args.use_te = True
        self.args.use_fp8 = False
        self.args.save_model = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_te = True
        self.args.use_fp8 = True
        self.args.save_model = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8_calibration(self):
        """Test Transformer Engine with FP8 calibration"""
        self.args.use_te = True
        self.args.use_fp8 = False
        self.args.use_fp8_infer = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)


if __name__ == "__main__":
    train_and_evaluate(mnist_parser(None))
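
# Typical standalone invocations (an FP8-capable GPU is assumed for --use-fp8):
#   python test_single_gpu_mnist.py --use-te
#   python test_single_gpu_mnist.py --use-te --use-fp8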