pytorch/training_loss.py
import argparse
import inspect
import os
import sys
import time
from datetime import timedelta

import torch

import torch._dynamo
from datasets import load_dataset, load_metric
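# Note: `load_metric` has been removed from recent `datasets` releases; on
# newer stacks the equivalent is `evaluate.load("accuracy")` from the
# separate `evaluate` package. This script targets the older API.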
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

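# TF32 trades a little float32 matmul precision for speed on Ampere-and-newer
# GPUs; it is generally considered safe for training workloads like this one.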
torch.backends.cuda.matmul.allow_tf32 = True

# You will download around 84 GB of data if you run this end-to-end
# training/evaluation example.

os.environ["TOKENIZERS_PARALLELISM"] = "false"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def data_processing(num_samples, batch_size):
    dataset = load_dataset("yelp_review_full")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    # Padding every example to the tokenizer's max length keeps batch shapes
    # static, which avoids TorchDynamo recompilations across batches.
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)

    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    tokenized_datasets.set_format("torch")

    small_train_dataset = tokenized_datasets["train"].select(range(num_samples))
    small_eval_dataset = tokenized_datasets["test"].select(range(num_samples))

    train_dataloader = DataLoader(small_train_dataset, batch_size=batch_size)
    eval_dataloader = DataLoader(small_eval_dataset, batch_size=batch_size)

    return train_dataloader, eval_dataloader


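# The whole training iteration (forward, backward, and optimizer step) is one
# function so that a compiler backend wrapping it can capture all three
# phases together, not just the forward pass.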
def training_iter_fn(batch, model, optimizer):
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss


def model_training_evaluation(
    backend, train_dataloader, eval_dataloader, model, optimizer, num_epochs, evaluation
):
    model.to(device)
    model.train()
    loss_history = []
    if not backend:
        # Run with native PyTorch
        opt_training_iter_fn = training_iter_fn
    else:
        # Supported backends include eager, aot_eager, aot_nvfuser, and inductor
        opt_training_iter_fn = torch._dynamo.optimize(backend)(training_iter_fn)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = opt_training_iter_fn(batch, model, optimizer)
            running_loss += loss.item()
            # Log the average loss over each window of 100 batches
            if i % 100 == 99:
                loss_history.append(running_loss / 100)
                running_loss = 0.0

    if evaluation:
        metric = load_metric("accuracy")
        model.eval()
        if not backend:
            opt_model = model
        else:
            opt_model = torch._dynamo.optimize(backend)(model)
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = opt_model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])

        return loss_history, metric.compute()
    else:
        return loss_history, None
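

# Note: torch._dynamo.optimize(backend) is the legacy entry point. On
# PyTorch 2.x the public equivalent is torch.compile — a sketch under that
# assumption:
#
#   opt_training_iter_fn = torch.compile(training_iter_fn, backend=backend)
#   opt_model = torch.compile(model, backend=backend)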


def check_loss(ref_loss, res_loss):
    assert len(ref_loss) == len(res_loss)
    # Average the last (up to 10) logged losses and compare with a small
    # absolute tolerance.
    x = min(len(ref_loss), 10)
    if x == 0:
        # Nothing was logged (fewer than 100 training batches); treat as a pass.
        return True
    return sum(res_loss[-x:]) / x <= sum(ref_loss[-x:]) / x + 1e-1
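# A hypothetical example: ref_loss = [0.9, 0.8] and res_loss = [0.95, 0.82]
# passes, since mean(res_loss) = 0.885 <= mean(ref_loss) + 0.1 = 0.95.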


def parse_args():
    parser = argparse.ArgumentParser(
        description="TorchDynamo end-to-end training/evaluation benchmark"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="number of epochs to train (default: 10)"
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=1000,
        help="number of samples to train/eval (default: 1000)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 8)",
    )
    parser.add_argument(
        "--lr", type=float, default=5e-5, help="learning rate (default: 5e-5)"
    )
    parser.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        default="inductor",
        help="train/evaluate the model with a given backend (default: inductor)",
    )
    parser.add_argument(
        "--optimizer",
        default="Adam",
        help="train the model using a given optimizer (default: Adam)",
    )
    parser.add_argument(
        "--evaluation",
        action="store_true",
        help="run evaluation after model training",
    )
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    train_dataloader, eval_dataloader = data_processing(
        args.num_samples, args.batch_size
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-cased", num_labels=5
    )
    optimizer_cls = getattr(sys.modules["torch.optim"], args.optimizer)
    # Optimizers that accept capturable=True are safe to capture (e.g., in
    # CUDA graphs), which lets backends compile the optimizer step as well.
    if "capturable" in inspect.signature(optimizer_cls).parameters.keys():
        optimizer = optimizer_cls(model.parameters(), lr=args.lr, capturable=True)
    else:
        optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    native_start = time.time()
    ref_loss, accuracy = model_training_evaluation(
        None,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    native_end = time.time()
    res_loss, accuracy = model_training_evaluation(
        args.backend,
        train_dataloader,
        eval_dataloader,
        model,
        optimizer,
        args.epochs,
        args.evaluation,
    )
    dynamo_end = time.time()
    if check_loss(ref_loss, res_loss):
        print(
            "[PASSED] TorchDynamo end-to-end training loss is less than or equal to native PyTorch"
        )
    else:
        print(
            "[FAILED] TorchDynamo end-to-end training loss is greater than native PyTorch"
        )
    if args.evaluation:
        print(f"Model accuracy: {accuracy}")
    native_elapsed = native_end - native_start
    dynamo_elapsed = dynamo_end - native_end
    print(
        f"Train model on {args.epochs} epochs with backend {args.backend} and optimizer {args.optimizer}:"
    )
    print(f"PyTorch spent {timedelta(seconds=native_elapsed / args.epochs)} per epoch")
    print(
        f"TorchDynamo spent {timedelta(seconds=dynamo_elapsed / args.epochs)} per epoch"
    )


if __name__ == "__main__":
    main()
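
# Example invocation (a quick single-epoch run with the default 1000 samples
# and the default inductor backend):
#
#   python training_loss.py --epochs 1 --evaluation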