TransformerEngine
274 lines · 8.8 KB
1# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2#
3# See LICENSE for license information.
4"""MNIST example of Transformer Engine Paddle"""
5
6import argparse
7import os
8import unittest
9
10import paddle
11from paddle import nn
12import paddle.nn.functional as F
13
14from paddle.vision.transforms import Normalize
15from paddle.io import DataLoader
16from paddle.vision.datasets import MNIST
17from paddle.metric import Accuracy
18
19import transformer_engine.paddle as te
20from transformer_engine.paddle.fp8 import is_fp8_available
21
22
class Net(nn.Layer):
    """Simple convolutional network used to train on MNIST.

    When ``use_te`` is True the two hidden fully-connected layers are
    Transformer Engine ``te.Linear`` layers; otherwise plain ``nn.Linear``
    layers with identical shapes are used. The final classifier layer is
    always a standard ``nn.Linear``.
    """

    def __init__(self, use_te=False):
        super().__init__()
        self.conv1 = nn.Conv2D(1, 32, 3, 1)
        self.conv2 = nn.Conv2D(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 spatial after conv/pool on 28x28 input.
        # NOTE(review): the 16-wide hidden layer presumably keeps dimensions
        # FP8-friendly for te.Linear — confirm against TE requirements.
        fc_cls = te.Linear if use_te else nn.Linear
        self.fc1 = fc_cls(9216, 128)
        self.fc2 = fc_cls(128, 16)
        self.fc3 = nn.Linear(16, 10)

    def forward(self, x):
        """Forward pass: two conv+ReLU stages, pool, then the FC stack."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        hidden = self.dropout1(F.max_pool2d(hidden, 2))
        hidden = paddle.flatten(hidden, 1)
        hidden = self.dropout2(F.relu(self.fc1(hidden)))
        return self.fc3(self.fc2(hidden))
55
56
def train(args, model, train_loader, optimizer, epoch, use_fp8):
    """Run one training epoch.

    Returns the epoch's average loss, or the last batch loss when
    ``args.dry_run`` cuts the epoch short after the first logged batch.
    """
    model.train()
    batch_losses = []
    for batch_id, (data, labels) in enumerate(train_loader):
        # BF16 autocast wraps the FP8 region, mirroring the evaluation path.
        with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):  # pylint: disable=not-context-manager
            with te.fp8_autocast(enabled=use_fp8):
                outputs = model(data)
            loss = F.cross_entropy(outputs, labels)
            batch_losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.clear_gradients()

        if batch_id % args.log_interval == 0:
            seen = batch_id * len(data)
            total = len(train_loader.dataset)
            pct = 100. * batch_id / len(train_loader)
            print(f"Train Epoch: {epoch} "
                  f"[{seen}/{total} "
                  f"({pct:.0f}%)]\t"
                  f"Loss: {loss.item():.6f}")
            if args.dry_run:
                return loss.item()
    avg_loss = sum(batch_losses) / len(batch_losses)
    print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}")
    return avg_loss
82
83
def evaluate(model, test_loader, epoch, use_fp8):
    """Evaluate the model on the test set and return the accumulated accuracy."""
    model.eval()
    metric = Accuracy()
    metric.reset()

    with paddle.no_grad():
        for data, labels in test_loader:
            with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):  # pylint: disable=not-context-manager
                with te.fp8_autocast(enabled=use_fp8):
                    outputs = model(data)
            correct = metric.compute(outputs, labels)
            metric.update(correct)
    accuracy = metric.accumulate()
    print(f"Epoch[{epoch}] - accuracy: {accuracy:.6f}")
    return accuracy
99
100
def calibrate(model, test_loader):
    """Run forward passes with FP8 calibration enabled (no gradient updates).

    Used to gather FP8 scaling statistics so a model trained without FP8
    can still run FP8 inference.
    """
    model.eval()

    with paddle.no_grad():
        for data, _ in test_loader:
            with paddle.amp.auto_cast(dtype='bfloat16', level='O2'):  # pylint: disable=not-context-manager
                with te.fp8_autocast(enabled=False, calibrating=True):
                    model(data)
110
111
def mnist_parser(args):
    """Parse training settings.

    Args:
        args: sequence of CLI tokens, or None to read from ``sys.argv``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Paddle MNIST Example")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.001,
        metavar="LR",
        help="learning rate (default: 0.001)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        default=False,
        help="quickly check a single pass",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="For Saving the current Model",
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--use-fp8",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference and training without recalibration. "
                        "It also enables Transformer Engine implicitly.")
    # Fixed help-text typo: "infernece" -> "inference".
    parser.add_argument("--use-fp8-infer",
                        action="store_true",
                        default=False,
                        help="Use FP8 for inference only. If not using FP8 for training, "
                        "calibration is performed for FP8 inference.")
    parser.add_argument("--use-te",
                        action="store_true",
                        default=False,
                        help="Use Transformer Engine")
    args = parser.parse_args(args)
    return args
179
180
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)

    paddle.seed(args.seed)

    # Load MNIST; Normalize maps raw [0, 255] pixels to roughly [-1, 1].
    transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
    train_dataset = MNIST(mode='train', transform=transform)
    val_dataset = MNIST(mode='test', transform=transform)

    # Data loaders: shuffle only the training split.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size)

    # Model and optimizer.
    model = Net(use_te=args.use_te)
    optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())

    # Cast model to BF16.
    model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')

    for epoch in range(1, args.epochs + 1):
        loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
        acc = evaluate(model, val_loader, epoch, args.use_fp8)

    # FP8 inference without FP8 training requires a calibration pass first.
    if args.use_fp8_infer and not args.use_fp8:
        calibrate(model, val_loader)

    if args.save_model or args.use_fp8_infer:
        paddle.save(model.state_dict(), "mnist_cnn.pdparams")
        print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8))
        weights = paddle.load("mnist_cnn.pdparams")
        model.set_state_dict(weights)
        acc = evaluate(model, val_loader, 0, args.use_fp8)

    return loss, acc
218
219
class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

    # Queried once at class-definition time; ``reason`` explains a skip.
    gpu_has_fp8, reason = is_fp8_available()

    @classmethod
    def setUpClass(cls):
        """Run MNIST without Transformer Engine"""
        cls.args = mnist_parser(["--epochs", "5"])

    @staticmethod
    def verify(actual):
        """Check if loss and accuracy match targets.

        ``actual`` is the (loss, accuracy) pair from train_and_evaluate.
        """
        # Fixed misspelled local: desired_traing_loss -> desired_training_loss.
        desired_training_loss = 0.1
        desired_test_accuracy = 0.98
        assert actual[0] < desired_training_loss
        assert actual[1] > desired_test_accuracy

    @unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0),
                     "BF16 MNIST example requires Ampere+ GPU")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        self.args.use_te = True
        self.args.use_fp8 = False
        self.args.save_model = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8(self):
        """Test Transformer Engine with FP8"""
        self.args.use_te = True
        self.args.use_fp8 = True
        self.args.save_model = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)

    @unittest.skipIf(not gpu_has_fp8, reason)
    def test_te_fp8_calibration(self):
        """Test Transformer Engine with FP8 calibration"""
        self.args.use_te = True
        self.args.use_fp8 = False
        self.args.use_fp8_infer = True
        actual = train_and_evaluate(self.args)
        if os.path.exists("mnist_cnn.pdparams"):
            os.remove("mnist_cnn.pdparams")
        self.verify(actual)
271
272
if __name__ == "__main__":
    # None -> argparse reads options from sys.argv.
    cli_args = mnist_parser(None)
    train_and_evaluate(cli_args)
275