pytorch

Форк
0
/
huggingface.py 
627 строк · 20.2 Кб
1
#!/usr/bin/env python3
2

3
import importlib
4
import logging
5
import os
6
import re
7
import subprocess
8
import sys
9
import warnings
10

11

12
try:
13
    from .common import (
14
        BenchmarkRunner,
15
        download_retry_decorator,
16
        load_yaml_file,
17
        main,
18
        reset_rng_state,
19
    )
20
except ImportError:
21
    from common import (
22
        BenchmarkRunner,
23
        download_retry_decorator,
24
        load_yaml_file,
25
        main,
26
        reset_rng_state,
27
    )
28

29
import torch
30
from torch._dynamo.testing import collect_results
31
from torch._dynamo.utils import clone_inputs
32

33

34
log = logging.getLogger(__name__)
35

36
# Enable FX graph caching
37
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
38
    torch._inductor.config.fx_graph_cache = True
39

40

41
def pip_install(package):
    """Install *package* with pip into the interpreter running this script.

    Raises subprocess.CalledProcessError if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(command)


# The transformers classes below are bound dynamically at import time, so
# static checkers cannot see them (presumably they would be reported as
# undefined names). Flake8 cannot silence a single warning type for a whole
# file, so it is disabled for this file entirely.
# flake8: noqa
imports = """
    AlbertForPreTraining
    AutoConfig
    AutoModelForCausalLM
    AutoModelForMaskedLM
    AutoModelForSeq2SeqLM
    BigBirdConfig
    BlenderbotForConditionalGeneration
    BlenderbotModel
    BlenderbotSmallForConditionalGeneration
    BlenderbotSmallModel
    CLIPModel
    CLIPVisionModel
    ElectraForPreTraining
    GPT2ForSequenceClassification
    GPTJForSequenceClassification
    GPTNeoForSequenceClassification
    HubertForSequenceClassification
    LxmertForPreTraining
    LxmertForQuestionAnswering
    MarianForCausalLM
    MarianModel
    MarianMTModel
    PegasusForConditionalGeneration
    PegasusModel
    ReformerConfig
    ViTForImageClassification
    ViTForMaskedImageModeling
    ViTModel
""".split()


def process_hf_reformer_output(out):
    """Return a new list holding every element of *out* except index 1.

    Args:
        out: list of model outputs.
    """
    assert isinstance(out, list)
    # The element at index 1 is unstable between runs, so it is dropped.
    return out[:1] + out[2:]


# Make sure a transformers build exposing every class named in `imports` is
# available (installing it from GitHub if not), then bind each of those
# classes as a module-level name.
try:
    mod = importlib.import_module("transformers")
    for cls in imports:
        if not hasattr(mod, cls):
            raise ModuleNotFoundError
except ModuleNotFoundError:
    print("Installing HuggingFace Transformers...")
    pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers")
    # An outdated copy may already be cached in sys.modules. Drop it so the
    # fresh install is the one imported. Also refresh the import finders'
    # caches so a newly created package directory is noticed.
    for _stale in [
        name
        for name in sys.modules
        if name == "transformers" or name.startswith("transformers.")
    ]:
        del sys.modules[_stale]
    importlib.invalidate_caches()
    mod = importlib.import_module("transformers")

# Equivalent of `from transformers import <cls>` for each required class,
# done without exec(). A missing class still surfaces as ImportError.
for cls in imports:
    try:
        globals()[cls] = getattr(mod, cls)
    except AttributeError as err:
        raise ImportError(f"cannot import name {cls!r} from 'transformers'") from err


# These models contain the models present in huggingface_models_list. It is a
# combination of models supported by HF Fx parser and some manually supplied
# models. For these models, we already know the largest batch size that can fit
# on A100 GPUs - 40 GB.
BATCH_SIZE_KNOWN_MODELS = {}


# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Read "<model_name>,<batch_size>" lines from the model list next to this file.
MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
if not os.path.exists(MODELS_FILENAME):
    raise FileNotFoundError(f"Model list not found: {MODELS_FILENAME}")
with open(MODELS_FILENAME, "r", encoding="utf-8") as fh:
    for _raw_line in fh:
        _line = _raw_line.strip()
        if not _line:
            # Tolerate blank lines (e.g. an extra newline at end of file);
            # they previously crashed the two-field unpacking below.
            continue
        _model_name, _batch_size = _line.split(",")
        BATCH_SIZE_KNOWN_MODELS[_model_name.strip()] = int(_batch_size)
if not BATCH_SIZE_KNOWN_MODELS:
    raise RuntimeError(f"No models listed in {MODELS_FILENAME}")


def get_module_cls_by_model_name(model_cls_name):
    """Return the transformers class named *model_cls_name*.

    Most classes are looked up on the top-level ``transformers`` package. The
    decoder classes listed below are fetched from their defining submodules
    instead.

    Raises:
        AttributeError: if the resolved module has no such attribute.
    """
    submodule_overrides = {
        "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
        "TrOCRDecoder": "transformers.models.trocr.modeling_trocr",
    }
    if model_cls_name in submodule_overrides:
        module_name = submodule_overrides[model_cls_name]
    else:
        module_name = "transformers"
    owner = importlib.import_module(module_name)
    return getattr(owner, model_cls_name)


def get_sequence_length(model_cls, model_name):
    """Return the input sequence length to benchmark *model_name* with.

    Args:
        model_cls: Model class (currently unused; kept for interface stability).
        model_name: Benchmark model name, matched by prefix / substring / exact
            name against the rules below (first match wins).

    Returns:
        int sequence length; 128 (with an info log) when no rule matches.
    """
    if model_name.startswith(("Blenderbot",)):
        seq_length = 128
    elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")):
        seq_length = 1024
    elif model_name in ("AllenaiLongformerBase", "BigBird"):
        seq_length = 1024
    elif model_name.startswith("OPT"):
        seq_length = 2048
    elif "Reformer" in model_name:
        seq_length = 4096
    elif model_name.startswith(
        (
            "Albert",
            "Deberta",
            "Layout",
            "Electra",
            "XLNet",
            "MegatronBert",
            "Bert",
            "Roberta",
        )
    ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"):
        seq_length = 512
    elif model_name == "TrOCRForCausalLM":
        # Exact match. The former `in ("TrOCRForCausalLM")` tested for a
        # substring of that string, so e.g. "" or "ForCausalLM" matched too.
        seq_length = 256
    elif model_name.startswith("MobileBert"):
        seq_length = 128
    elif model_name.startswith("Wav2Vec2"):
        # If too short, will fail with something like
        # ValueError: `mask_length` has to be smaller than `sequence_length`,
        # but got `mask_length`: 10 and `sequence_length`: 9`
        seq_length = 10000  # NB: a more realistic size is 155136
    else:
        log.info(
            "Sequence Length not defined for %s. Choosing 128 arbitrarily",
            model_name,
        )
        seq_length = 128
    return seq_length


def generate_inputs_for_model(
    model_cls, model, model_name, bs, device, include_loss_args=False
):
    """Build a dict of random keyword inputs for calling ``model(**inputs)``.

    Args:
        model_cls: Model class; compared against several transformers classes
            to decide which extra inputs are needed.
        model: Instantiated model; ``model.config`` supplies ``vocab_size`` and
            other dimensions.
        model_name: Benchmark name, used for prefix/suffix based dispatch.
        bs: Batch size.
        device: Device on which every generated tensor is allocated.
        include_loss_args: When True, also add the label tensors the model
            needs to compute a loss.

    Returns:
        dict mapping forward-argument names to tensors.

    Raises:
        NotImplementedError: if *include_loss_args* is set and no labeling rule
            matches *model_name*.
    """
    # TODO - Check if following values are representative
    num_choices = 3
    num_visual_features = 42
    seq_length = get_sequence_length(model_cls, model_name)
    vocab_size = model.config.vocab_size

    if model_name.startswith("Wav2Vec2"):
        # TODO: If we add more input_values style models, try to work this
        # into the overall control flow
        target_length = 100
        return {
            "input_values": torch.randn((bs, seq_length), device=device),
            # Added because that's what the example training script has
            "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)),
            "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)),
        }

    if model_name.endswith("MultipleChoice"):
        input_ids = rand_int_tensor(
            device, 0, vocab_size, (bs, num_choices, seq_length)
        )
    elif model_name.startswith("Roberta"):
        # NOTE(review): high=1 makes every token id 0; presumably deliberate.
        # Confirm.
        input_ids = rand_int_tensor(device, 0, 1, (bs, seq_length))
    else:
        input_ids = rand_int_tensor(device, 0, vocab_size, (bs, seq_length))

    if "Bart" in model_name:
        # Force the last token of every sequence to EOS (presumably so each
        # sequence contains an EOS token).
        input_ids[:, -1] = model.config.eos_token_id

    input_dict = {"input_ids": input_ids}

    # These models also take decoder inputs; the encoder ids are reused.
    if model_name.startswith(("T5", "M2M100", "MT5")) or model_cls in [
        BlenderbotModel,
        BlenderbotSmallModel,
        BlenderbotForConditionalGeneration,
        BlenderbotSmallForConditionalGeneration,
        PegasusModel,
        PegasusForConditionalGeneration,
        MarianModel,
        MarianMTModel,
    ]:
        input_dict["decoder_input_ids"] = input_ids

    if model_name.startswith("Lxmert"):
        visual_feat_dim, visual_pos_dim = (
            model.config.visual_feat_dim,
            model.config.visual_pos_dim,
        )
        # Allocate on `device` like every other input (these were previously
        # always created on the default CPU device).
        input_dict["visual_feats"] = torch.randn(
            bs, num_visual_features, visual_feat_dim, device=device
        )
        input_dict["visual_pos"] = torch.randn(
            bs, num_visual_features, visual_pos_dim, device=device
        )

    if include_loss_args:
        if model_name.endswith("PreTraining"):
            if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
                input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length))
            else:
                label_name = (
                    "sentence_order_label"
                    if model_cls in [AlbertForPreTraining]
                    else "next_sentence_label"
                )
                # Token-level labels over the vocabulary. Must be a tensor; a
                # stray trailing comma used to wrap it in a 1-tuple.
                input_dict["labels"] = rand_int_tensor(
                    device, 0, vocab_size, (bs, seq_length)
                )
                input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("QuestionAnswering"):
            input_dict["start_positions"] = rand_int_tensor(
                device, 0, seq_length, (bs,)
            )
            input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
        elif model_name.endswith(
            ("MaskedLM", "HeadModel", "CausalLM", "DoubleHeadsModel")
        ):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        elif model_name.endswith("TokenClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs, seq_length)
            )
        elif model_name.endswith("MultipleChoice"):
            input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,))
        elif model_name.endswith("SequenceClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs,)
            )
        elif model_name.endswith("NextSentencePrediction"):
            input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("ForConditionalGeneration"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size - 1, (bs, seq_length)
            )
        elif model_name in EXTRA_MODELS:
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        else:
            raise NotImplementedError(
                f"Class {model_name} unsupported for training test "
            )

    return input_dict


def rand_int_tensor(device, low, high, shape):
    """Return an int64 tensor of *shape* with values drawn from [low, high).

    The tensor is allocated on *device* and has ``requires_grad=False``.
    """
    tensor_options = dict(device=device, dtype=torch.int64, requires_grad=False)
    return torch.randint(low, high, shape, **tensor_options)


# Benchmarks that are not listed in huggingface_models_list.txt. Each entry
# maps a benchmark name to a (config, model_class) pair. The config objects are
# built once, when this module is imported, and are shared by every use of the
# entry.
EXTRA_MODELS = dict(
    AllenaiLongformerBase=(
        AutoConfig.from_pretrained("allenai/longformer-base-4096"),
        AutoModelForMaskedLM,
    ),
    Reformer=(ReformerConfig(), AutoModelForMaskedLM),
    T5Small=(AutoConfig.from_pretrained("t5-small"), AutoModelForSeq2SeqLM),
    # Currently disabled:
    # BigBird=(
    #     BigBirdConfig(attention_type="block_sparse"),
    #     AutoModelForMaskedLM,
    # ),
    DistillGPT2=(AutoConfig.from_pretrained("distilgpt2"), AutoModelForCausalLM),
    GoogleFnet=(
        AutoConfig.from_pretrained("google/fnet-base"),
        AutoModelForMaskedLM,
    ),
    YituTechConvBert=(
        AutoConfig.from_pretrained("YituTech/conv-bert-base"),
        AutoModelForMaskedLM,
    ),
    CamemBert=(AutoConfig.from_pretrained("camembert-base"), AutoModelForMaskedLM),
)


class HuggingfaceRunner(BenchmarkRunner):
    """Benchmark runner for the HuggingFace Transformers model suite.

    Suite-specific settings (skip lists, tolerances, batch-size divisors, ...)
    come from ``huggingface.yaml`` via ``load_yaml_file``.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "huggingface"

    @property
    def _config(self):
        # Parsed huggingface.yaml; obtained from load_yaml_file on every access.
        return load_yaml_file("huggingface.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def _accuracy(self):
        return self._config["accuracy"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

    @property
    def fp32_only_models(self):
        return self._config["only_fp32"]

    @property
    def skip_models_due_to_control_flow(self):
        return self._skip["control_flow"]

    def _get_model_cls_and_config(self, model_name):
        """Return ``(model_cls, config)`` for *model_name*.

        EXTRA_MODELS names use the shared (config, class) pair stored there.
        Every other name is resolved to a transformers class and paired with a
        freshly constructed default config.
        """
        if model_name in EXTRA_MODELS:
            config, model_cls = EXTRA_MODELS[model_name]
            return model_cls, config

        model_cls = get_module_cls_by_model_name(model_name)
        config = model_cls.config_class()

        # NB: some models need a pad token defined to handle BS > 1
        if model_cls in [
            GPT2ForSequenceClassification,
            GPTNeoForSequenceClassification,
            GPTJForSequenceClassification,
        ] or model_cls.__name__.startswith(("Roberta", "Marian")):
            config.pad_token_id = 0
        return model_cls, config

    @download_retry_decorator
    def _download_model(self, model_name):
        """Construct the model for *model_name* from its config.

        Auto classes go through ``from_config``; other classes are called with
        the config directly.
        """
        model_cls, config = self._get_model_cls_and_config(model_name)
        if "auto" in model_cls.__module__:
            # Handle auto classes
            return model_cls.from_config(config)
        return model_cls(config)

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        """Build *model_name* on *device* together with its example inputs.

        Args:
            device: Target device for the model and the generated inputs.
            model_name: A transformers class name or an EXTRA_MODELS key.
            batch_size: Explicit batch size. When None, the default comes from
                BATCH_SIZE_KNOWN_MODELS (or 16) and may then be reduced by the
                YAML ``batch_size.divisors`` entry for this model.
            extra_args: Not used by this implementation.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)``.
        """
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dtype = torch.float32
        reset_rng_state()
        model_cls, config = self._get_model_cls_and_config(model_name)
        model = self._download_model(model_name)
        model = model.to(device, dtype=dtype)
        if self.args.enable_activation_checkpointing:
            model.gradient_checkpointing_enable()
        if model_name in BATCH_SIZE_KNOWN_MODELS:
            batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
        elif batch_size is None:
            batch_size_default = 16
            log.info(
                f"Batch size not specified for {model_name}. Setting batch_size=16"
            )

        if batch_size is None:
            batch_size = batch_size_default
            batch_size_divisors = self._config["batch_size"]["divisors"]
            if model_name in batch_size_divisors:
                batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
                log.info(
                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
                )

        example_inputs = generate_inputs_for_model(
            model_cls, model, model_name, batch_size, device, include_loss_args=True
        )

        # So we can check for correct gradients without eliminating the dropout computation
        # NOTE(review): except for EXTRA_MODELS (whose config objects are
        # shared), `config` here is a different instance from the one
        # _download_model() built the model with. The model is also already
        # constructed, so these writes may never reach its dropout layers.
        # Confirm intent.
        for attr in dir(config):
            if "drop" in attr and isinstance(getattr(config, attr), float):
                setattr(config, attr, 1e-30)

        if (
            is_training
            and not use_eval_mode
            and not (
                self.args.accuracy and model_name in self._config["only_inference"]
            )
        ):
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield the selected model names in sorted order.

        Applies the benchmark index window, the ``--filter``/``--exclude``
        regexes (case-insensitive), ``exclude_exact`` and the YAML skip list.
        """
        model_names = sorted(set(BATCH_SIZE_KNOWN_MODELS) | set(EXTRA_MODELS))
        include_pattern = "|".join(args.filter)
        exclude_pattern = "|".join(args.exclude)

        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search(include_pattern, model_name, re.I)
                or re.search(exclude_pattern, model_name, re.I)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return self._accuracy["skip"]["large_models"]
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    def pick_grad(self, name, is_training):
        """Return the grad-mode context manager for a training/inference run."""
        if is_training:
            return torch.enable_grad()
        return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, use_cosine)`` for accuracy checks of *name*."""
        cosine = self.args.cosine
        tolerance_cfg = self._config["tolerance"]
        if is_training:
            from torch._inductor import config as inductor_config

            if name in tolerance_cfg["higher_training"] or (
                inductor_config.max_autotune
                and name in tolerance_cfg["higher_max_autotune_training"]
            ):
                return 2e-2, cosine
            return 1e-2, cosine

        if name in tolerance_cfg["higher_inference"]:
            return 4e-3, cosine
        # CPU-only relaxation list. It is optional in the YAML, so a missing
        # key simply means no extra relaxation. (This branch used to re-test
        # "higher_inference", which made it unreachable.)
        if current_device == "cpu" and name in tolerance_cfg.get(
            "higher_inference_cpu", ()
        ):
            return 4e-3, cosine
        return 1e-3, cosine

    def compute_loss(self, pred):
        # The first element of the model output is used as the loss.
        return pred[0]

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(**inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one forward + backward + optimizer step on cloned *inputs*."""
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(**cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def refresh_model_names_and_batch_sizes():
    """Regenerate the model list from the classes HF's FX tracer supports.

    Picks representative ``*For*`` classes per model family from
    ``transformers.utils.fx._SUPPORTED_MODELS``, adds EXTRA_MODELS, and for
    each one re-runs this script with ``--find-batch-sizes`` so that the
    largest batch size fitting on the GPU with PyTorch eager is written to
    huggingface_models_list.txt.

    Note - this only needs to be run when the HF FX tracer is believed to
    support more models than before.
    """
    import transformers.utils.fx as hf_fx

    # Classes excluded from the refresh because of known failures.
    excluded_classes = (
        # TODO: AttributeError: '*Config' object has no attribute 'vocab_size'
        CLIPModel,
        CLIPVisionModel,
        # SwinForImageClassification,
        # SwinForMaskedImageModeling,
        # SwinModel,
        ViTForImageClassification,
        ViTForMaskedImageModeling,
        ViTModel,
        # TODO: AssertionError: Padding_idx must be within num_embeddings
        MarianForCausalLM,
        MarianMTModel,
        MarianModel,
        # TODO: "model is not supported yet" from HFTracer
        HubertForSequenceClassification,
        # TODO: shape mismatch in loss calculation
        LxmertForQuestionAnswering,
    )
    task_suffixes = ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering")

    families = {}
    lm_seen = set()
    family_seen = set()
    for cls_name in hf_fx._SUPPORTED_MODELS:
        if "For" not in cls_name:
            continue
        if get_module_cls_by_model_name(cls_name) in excluded_classes:
            continue

        family_name = cls_name.split("For")[0]
        members = families.setdefault(family_name, [])
        # At most one LM head and one task head per family; every image
        # classification head is kept.
        if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen:
            members.append(cls_name)
            lm_seen.add(family_name)
        elif cls_name.endswith(task_suffixes) and family_name not in family_seen:
            members.append(cls_name)
            family_seen.add(family_name)
        elif cls_name.endswith("ImageClassification"):
            members.append(cls_name)

    chosen_models = {name for members in families.values() for name in members}
    chosen_models |= set(EXTRA_MODELS)

    for model_name in sorted(chosen_models):
        command = [
            sys.executable,
            *sys.argv,
            "--find-batch-sizes",
            f"--only={model_name}",
            f"--output={MODELS_FILENAME}",
        ]
        try:
            subprocess.check_call(command)
        except subprocess.SubprocessError:
            log.warning(f"Failed to find suitable batch size for {model_name}")


def huggingface_main():
    """Command-line entry point: quiet logging/warnings, then run the suite."""
    # To regenerate huggingface_models_list.txt, re-enable:
    # if "--find-batch-sizes" not in sys.argv:
    #     refresh_model_names_and_batch_sizes()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    runner = HuggingfaceRunner()
    main(runner)


if __name__ == "__main__":
    huggingface_main()
Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.