import argparse
import inspect
import os
import sys
import time
from datetime import timedelta

import torch

import torch._dynamo

# NOTE: load_metric is deprecated in newer `datasets` releases (it moved to the
# separate `evaluate` package); this example keeps the original API.
from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Allow TF32 matmuls for speed on Ampere and newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True

# You will download around 84 GB of data if you run this end-to-end training/evaluation example.
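# Example invocation (the file name is assumed here; substitute this script's
# actual path):
#   python training_loss.py --epochs 1 --num-samples 1000 --batch-size 8 --backend inductor --evaluation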
os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

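# Build DataLoaders over small train/eval slices of the tokenized Yelp review dataset.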
def data_processing(num_samples, batch_size):
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
    small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))

    train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
    eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

    return train_dataloader, eval_dataloader

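# One training step (forward, backward, optimizer update), kept as a standalone
# function so TorchDynamo can compile the entire step.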
def training_iter_fn(batch, model, optimizer):
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

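# Train the model for num_epochs, natively when backend is None, otherwise with
# the training step (and the model, for evaluation) compiled via TorchDynamo.
# Returns the per-100-iteration loss history and, optionally, eval accuracy.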
def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    model.to(device)
    model.train()
    loss_history = []
    if not backend:
        # Run with native PyTorch
        opt_training_iter_fn = training_iter_fn
    else:
        # Supported backends: eager, aot_eager, aot_nvfuser and inductor
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, batch in enumerate(train_dataloader, 0):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            if i % 100 == 99:
                # Record the average loss over the last 100 iterations
                loss_history.append(running_loss / 100)
                running_loss = 0.0

    if evaluation:
        metric = load_metric("accuracy")
        model.eval()
        if not backend:
            opt_model = model
        else:
            opt_model = torch._dynamo.optimize(backend)(model)
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = opt_model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        return loss_history, metric.compute()
    else:
        return loss_history, None

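# The TorchDynamo run passes if its mean loss over the last few recorded values
# is within 0.1 of the native run's.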
def check_loss(ref_loss, res_loss):
    assert len(ref_loss) == len(res_loss)
    length = len(ref_loss)
    x = min(length, 10)
    # Average over the last x entries (not a fixed 10, which would skew short runs)
    return sum(res_loss[-x:]) / x <= sum(ref_loss[-x:]) / x + 1e-1

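# Command-line options: epochs, sample count, batch size, learning rate,
# TorchDynamo backend, optimizer class, and an optional evaluation pass.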
def parse_args():
    parser = argparse.ArgumentParser(
        description="TorchDynamo end-to-end training/evaluation benchmark"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of samples to train/eval (default: 1000)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 8)",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
    )
    parser.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        default="inductor",
        help="train/evaluate model with a given backend (default: inductor)",
    )
    parser.add_argument(
        "--optimizer",
        default="Adam",
        help="train model using a given optimizer (default: Adam)",
    )
    parser.add_argument(
        "--evaluation",
        action="store_true",
        help="run evaluation after model training",
    )
    args = parser.parse_args()
    return args

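# Run training twice on the same model and optimizer, first natively and then
# through TorchDynamo, then compare losses and report per-epoch wall-clock time.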
def main():
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
    # Pass capturable=True when the optimizer supports it, so the optimizer step
    # is safe to capture (e.g. in a CUDA graph).
    if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
        optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
    else:
        optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    native_start = time.time()
    # Baseline run: native PyTorch (backend=None)
    ref_loss, accuracy = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_end = time.time()
    # Second run: the same model and optimizer, driven through TorchDynamo
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.time()
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end-to-end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
            "[FAILED] TorchDynamo end-to-end training loss is greater than native PyTorch"
        )
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed / args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed / args.epochs)} per epoch"
    )

if __name__ == "__main__":
    main()