#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings

try:
    from .common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )
except ImportError:
    from common import (
        BenchmarkRunner,
        download_retry_decorator,
        load_yaml_file,
        main,
        reset_rng_state,
    )

import torch
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import clone_inputs


log = logging.getLogger(__name__)

# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
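

# Illustrative usage of the helper above (a hedged example; the package spec is
# hypothetical and nothing is installed at import time):
#
#   pip_install("transformers==4.36.0")
#   # runs: <python> -m pip install transformers==4.36.0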


# Disable the flake warnings for the imports. Flake8 does not provide a way to
# disable just one warning for the entire file. Disabling flake8 entirely.
# flake8: noqa
imports = [
    "AlbertForPreTraining",
    "AutoConfig",
    "AutoModelForCausalLM",
    "AutoModelForMaskedLM",
    "AutoModelForSeq2SeqLM",
    "BigBirdConfig",
    "BlenderbotForConditionalGeneration",
    "BlenderbotModel",
    "BlenderbotSmallForConditionalGeneration",
    "BlenderbotSmallModel",
    "CLIPModel",
    "CLIPVisionModel",
    "ElectraForPreTraining",
    "GPT2ForSequenceClassification",
    "GPTJForSequenceClassification",
    "GPTNeoForSequenceClassification",
    "HubertForSequenceClassification",
    "LxmertForPreTraining",
    "LxmertForQuestionAnswering",
    "MarianForCausalLM",
    "MarianModel",
    "MarianMTModel",
    "PegasusForConditionalGeneration",
    "PegasusModel",
    "ReformerConfig",
    "ViTForImageClassification",
    "ViTForMaskedImageModeling",
    "ViTModel",
]


def process_hf_reformer_output(out):
    assert isinstance(out, list)
    # second output is unstable
    return [elem for i, elem in enumerate(out) if i != 1]


try:
    mod = importlib.import_module("transformers")
    for cls in imports:
        if not hasattr(mod, cls):
            raise ModuleNotFoundError
except ModuleNotFoundError:
    print("Installing HuggingFace Transformers...")
    pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers")
finally:
    for cls in imports:
        exec(f"from transformers import {cls}")


# This dict holds the models listed in huggingface_models_list.txt: a
# combination of models supported by the HF Fx parser and some manually
# supplied models. For each of these models, we already know the largest batch
# size that can fit on an A100 GPU (40 GB).
BATCH_SIZE_KNOWN_MODELS = {}


# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
MODELS_FILENAME = os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt")
assert os.path.exists(MODELS_FILENAME)
with open(MODELS_FILENAME, "r") as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(",")
        batch_size = int(batch_size)
        BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)
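
# For reference, each line of huggingface_models_list.txt is expected to be
# "<ModelClassName>,<batch_size>". A hypothetical entry (illustrative only):
#
#   BertForMaskedLM,16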


def get_module_cls_by_model_name(model_cls_name):
    _module_by_model_name = {
        "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
        "TrOCRDecoder": "transformers.models.trocr.modeling_trocr",
    }
    module_name = _module_by_model_name.get(model_cls_name, "transformers")
    module = importlib.import_module(module_name)
    return getattr(module, model_cls_name)
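

# Illustrative lookups (a sketch of the resolution logic above):
#
#   get_module_cls_by_model_name("TrOCRDecoder")
#   # -> transformers.models.trocr.modeling_trocr.TrOCRDecoder
#   get_module_cls_by_model_name("BertForMaskedLM")
#   # -> transformers.BertForMaskedLM (top-level namespace fallback)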


def get_sequence_length(model_cls, model_name):
    if model_name.startswith(("Blenderbot",)):
        seq_length = 128
    elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")):
        seq_length = 1024
    elif model_name in ("AllenaiLongformerBase", "BigBird"):
        seq_length = 1024
    elif model_name.startswith("OPT"):
        seq_length = 2048
    elif "Reformer" in model_name:
        seq_length = 4096
    elif model_name.startswith(
        (
            "Albert",
            "Deberta",
            "Layout",
            "Electra",
            "XLNet",
            "MegatronBert",
            "Bert",
            "Roberta",
        )
    ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"):
        seq_length = 512
    elif model_name in ("TrOCRForCausalLM",):
        seq_length = 256
    elif model_name.startswith("MobileBert"):
        seq_length = 128
    elif model_name.startswith("Wav2Vec2"):
        # If too short, will fail with something like
        # ValueError: `mask_length` has to be smaller than `sequence_length`,
        # but got `mask_length`: 10 and `sequence_length`: 9
        seq_length = 10000  # NB: a more realistic size is 155136
    else:
        log.info(
            f"Sequence length not defined for {model_name}. Choosing 128 arbitrarily"
        )
        seq_length = 128
    return seq_length
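

# Hedged examples of the mapping above (model_cls is unused, so None suffices):
#
#   get_sequence_length(None, "GPT2LMHeadModel")   # -> 1024
#   get_sequence_length(None, "SomeUnknownModel")  # -> 128, with an info log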


def generate_inputs_for_model(
    model_cls, model, model_name, bs, device, include_loss_args=False
):
    # TODO - Check if the following values are representative
    num_choices = 3
    num_visual_features = 42
    seq_length = get_sequence_length(model_cls, model_name)
    vocab_size = model.config.vocab_size

    if model_name.startswith("Wav2Vec2"):
        # TODO: If we add more input_values-style models, try to work this
        # into the overall control flow
        target_length = 100
        return {
            "input_values": torch.randn((bs, seq_length), device=device),
            # Added because that's what the example training script has
            "attention_mask": rand_int_tensor(device, 0, 2, (bs, seq_length)),
            "labels": rand_int_tensor(device, 0, vocab_size, (bs, target_length)),
        }

    if model_name.endswith("MultipleChoice"):
        input = rand_int_tensor(device, 0, vocab_size, (bs, num_choices, seq_length))
    elif model_name.startswith("Roberta"):
        input = rand_int_tensor(device, 0, 1, (bs, seq_length))
    else:
        input = rand_int_tensor(device, 0, vocab_size, (bs, seq_length))

    if "Bart" in model_name:
        input[:, -1] = model.config.eos_token_id

    input_dict = {"input_ids": input}

    if (
        model_name.startswith("T5")
        or model_name.startswith("M2M100")
        or model_name.startswith("MT5")
        or model_cls
        in [
            BlenderbotModel,
            BlenderbotSmallModel,
            BlenderbotForConditionalGeneration,
            BlenderbotSmallForConditionalGeneration,
            PegasusModel,
            PegasusForConditionalGeneration,
            MarianModel,
            MarianMTModel,
        ]
    ):
        input_dict["decoder_input_ids"] = input

    if model_name.startswith("Lxmert"):
        visual_feat_dim, visual_pos_dim = (
            model.config.visual_feat_dim,
            model.config.visual_pos_dim,
        )
        input_dict["visual_feats"] = torch.randn(
            bs, num_visual_features, visual_feat_dim
        )
        input_dict["visual_pos"] = torch.randn(bs, num_visual_features, visual_pos_dim)

    if include_loss_args:
        if model_name.endswith("PreTraining"):
            if model_cls in [ElectraForPreTraining, LxmertForPreTraining]:
                input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs, seq_length))
            else:
                label_name = (
                    "sentence_order_label"
                    if model_cls in [AlbertForPreTraining]
                    else "next_sentence_label"
                )
                input_dict["labels"] = rand_int_tensor(
                    device, 0, vocab_size, (bs, seq_length)
                )
                input_dict[label_name] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("QuestionAnswering"):
            input_dict["start_positions"] = rand_int_tensor(
                device, 0, seq_length, (bs,)
            )
            input_dict["end_positions"] = rand_int_tensor(device, 0, seq_length, (bs,))
        elif (
            model_name.endswith("MaskedLM")
            or model_name.endswith("HeadModel")
            or model_name.endswith("CausalLM")
            or model_name.endswith("DoubleHeadsModel")
        ):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        elif model_name.endswith("TokenClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs, seq_length)
            )
        elif model_name.endswith("MultipleChoice"):
            input_dict["labels"] = rand_int_tensor(device, 0, num_choices, (bs,))
        elif model_name.endswith("SequenceClassification"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, model.config.num_labels - 1, (bs,)
            )
        elif model_name.endswith("NextSentencePrediction"):
            input_dict["labels"] = rand_int_tensor(device, 0, 1, (bs,))
        elif model_name.endswith("ForConditionalGeneration"):
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size - 1, (bs, seq_length)
            )
        elif model_name in EXTRA_MODELS:
            input_dict["labels"] = rand_int_tensor(
                device, 0, vocab_size, (bs, seq_length)
            )
        else:
            raise NotImplementedError(
                f"Class {model_name} unsupported for training test"
            )

    return input_dict


def rand_int_tensor(device, low, high, shape):
    return torch.randint(
        low,
        high,
        shape,
        device=device,
        dtype=torch.int64,
        requires_grad=False,
    )
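

# Illustrative call (a sketch; values vary per run): an int64 tensor of ids in
# [low, high), since torch.randint treats the upper bound as exclusive.
#
#   rand_int_tensor("cpu", 0, 30522, (4, 512))  # e.g. a BERT-sized vocab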


EXTRA_MODELS = {
    "AllenaiLongformerBase": (
        AutoConfig.from_pretrained("allenai/longformer-base-4096"),
        AutoModelForMaskedLM,
    ),
    "Reformer": (
        ReformerConfig(),
        AutoModelForMaskedLM,
    ),
    "T5Small": (
        AutoConfig.from_pretrained("t5-small"),
        AutoModelForSeq2SeqLM,
    ),
    # "BigBird": (
    #     BigBirdConfig(attention_type="block_sparse"),
    #     AutoModelForMaskedLM,
    # ),
    "DistillGPT2": (
        AutoConfig.from_pretrained("distilgpt2"),
        AutoModelForCausalLM,
    ),
    "GoogleFnet": (
        AutoConfig.from_pretrained("google/fnet-base"),
        AutoModelForMaskedLM,
    ),
    "YituTechConvBert": (
        AutoConfig.from_pretrained("YituTech/conv-bert-base"),
        AutoModelForMaskedLM,
    ),
    "CamemBert": (
        AutoConfig.from_pretrained("camembert-base"),
        AutoModelForMaskedLM,
    ),
}


class HuggingfaceRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "huggingface"

    @property
    def _config(self):
        return load_yaml_file("huggingface.yaml")

    @property
    def _skip(self):
        return self._config["skip"]

    @property
    def _accuracy(self):
        return self._config["accuracy"]

    @property
    def skip_models(self):
        return self._skip["all"]

    @property
    def skip_models_for_cpu(self):
        return self._skip["device"]["cpu"]

    @property
    def fp32_only_models(self):
        return self._config["only_fp32"]

    @property
    def skip_models_due_to_control_flow(self):
        return self._skip["control_flow"]

    def _get_model_cls_and_config(self, model_name):
        if model_name not in EXTRA_MODELS:
            model_cls = get_module_cls_by_model_name(model_name)
            config_cls = model_cls.config_class
            config = config_cls()

            # NB: some models need a pad token defined to handle BS > 1
            if (
                model_cls
                in [
                    GPT2ForSequenceClassification,
                    GPTNeoForSequenceClassification,
                    GPTJForSequenceClassification,
                ]
                or model_cls.__name__.startswith("Roberta")
                or model_cls.__name__.startswith("Marian")
            ):
                config.pad_token_id = 0
        else:
            config, model_cls = EXTRA_MODELS[model_name]

        return model_cls, config

    @download_retry_decorator
    def _download_model(self, model_name):
        model_cls, config = self._get_model_cls_and_config(model_name)
        if "auto" in model_cls.__module__:
            # Handle auto classes
            model = model_cls.from_config(config)
        else:
            model = model_cls(config)
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode
        dtype = torch.float32
        reset_rng_state()
        model_cls, config = self._get_model_cls_and_config(model_name)
        model = self._download_model(model_name)
        model = model.to(device, dtype=dtype)
        if self.args.enable_activation_checkpointing:
            model.gradient_checkpointing_enable()
        if model_name in BATCH_SIZE_KNOWN_MODELS:
            batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
        elif batch_size is None:
            batch_size_default = 16
            log.info(
                f"Batch size not specified for {model_name}. Setting batch_size=16"
            )

        if batch_size is None:
            batch_size = batch_size_default
            batch_size_divisors = self._config["batch_size"]["divisors"]
            if model_name in batch_size_divisors:
                batch_size = max(int(batch_size / batch_size_divisors[model_name]), 1)
                log.info(
                    f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"
                )

        example_inputs = generate_inputs_for_model(
            model_cls, model, model_name, batch_size, device, include_loss_args=True
        )

        # So we can check for correct gradients without eliminating the dropout
        # computation
        for attr in dir(config):
            if "drop" in attr and isinstance(getattr(config, attr), float):
                setattr(config, attr, 1e-30)
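
        # For example, an attribute such as config.hidden_dropout_prob (found
        # on BERT-style configs) becomes 1e-30: the dropout ops remain in the
        # traced graph but are numerically near no-ops, keeping eager and
        # compiled gradients comparable.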

        if (
            is_training
            and not use_eval_mode
            and not (
                self.args.accuracy and model_name in self._config["only_inference"]
            )
        ):
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)
        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        model_names = list(BATCH_SIZE_KNOWN_MODELS.keys()) + list(EXTRA_MODELS.keys())
        model_names = set(model_names)
        model_names = sorted(model_names)

        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.I)
                or re.search("|".join(args.exclude), model_name, re.I)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue
            yield model_name

    @property
    def skip_accuracy_checks_large_models_dashboard(self):
        if self.args.dashboard or self.args.accuracy:
            return self._accuracy["skip"]["large_models"]
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        if is_training:
            from torch._inductor import config as inductor_config

            if (name in self._config["tolerance"]["higher_training"]) or (
                inductor_config.max_autotune
                and name in self._config["tolerance"]["higher_max_autotune_training"]
            ):
                return 2e-2, cosine
            else:
                return 1e-2, cosine
        else:
            if name in self._config["tolerance"]["higher_inference"]:
                return 4e-3, cosine
            if (
                current_device == "cpu"
                and name in self._config["tolerance"]["higher_inference_cpu"]
            ):
                return 4e-3, cosine
            return 1e-3, cosine

    def compute_loss(self, pred):
        return pred[0]

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(**inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(**cloned_inputs)
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def refresh_model_names_and_batch_sizes():
    """
    This function reads the models supported by the HF Fx tracer and finds the
    largest batch size that can fit on the GPU with PyTorch eager mode.

    The resulting data is written to huggingface_models_list.txt.

    Note - We only need to run this function if we believe the HF Fx tracer
    now supports more models.
    """
    import transformers.utils.fx as hf_fx

    family = {}
    lm_seen = set()
    family_seen = set()
    for cls_name in hf_fx._SUPPORTED_MODELS:
        if "For" not in cls_name:
            continue

        model_cls = get_module_cls_by_model_name(cls_name)

        # TODO: AttributeError: '*Config' object has no attribute 'vocab_size'
        if model_cls in [
            CLIPModel,
            CLIPVisionModel,
            # SwinForImageClassification,
            # SwinForMaskedImageModeling,
            # SwinModel,
            ViTForImageClassification,
            ViTForMaskedImageModeling,
            ViTModel,
        ]:
            continue

        # TODO: AssertionError: Padding_idx must be within num_embeddings
        if model_cls in [MarianForCausalLM, MarianMTModel, MarianModel]:
            continue

        # TODO: "model is not supported yet" from HFTracer
        if model_cls in [HubertForSequenceClassification]:
            continue

        # TODO: shape mismatch in loss calculation
        if model_cls in [LxmertForQuestionAnswering]:
            continue

        family_name = cls_name.split("For")[0]
        if family_name not in family:
            family[family_name] = []
        if cls_name.endswith(("MaskedLM", "CausalLM")) and family_name not in lm_seen:
            family[family_name].append(cls_name)
            lm_seen.add(family_name)
        elif (
            cls_name.endswith(
                ("SequenceClassification", "ConditionalGeneration", "QuestionAnswering")
            )
            and family_name not in family_seen
        ):
            family[family_name].append(cls_name)
            family_seen.add(family_name)
        elif cls_name.endswith("ImageClassification"):
            family[family_name].append(cls_name)

    chosen_models = set()
    for members in family.values():
        chosen_models.update(set(members))

    # Add the EXTRA_MODELS
    chosen_models.update(set(EXTRA_MODELS.keys()))

    for model_name in sorted(chosen_models):
        try:
            subprocess.check_call(
                [sys.executable]
                + sys.argv
                + ["--find-batch-sizes"]
                + [f"--only={model_name}"]
                + [f"--output={MODELS_FILENAME}"]
            )
        except subprocess.SubprocessError:
            log.warning(f"Failed to find suitable batch size for {model_name}")
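

# The loop above re-invokes this same script once per model; the spawned
# command looks roughly like this (a hedged reconstruction, with a hypothetical
# model name):
#
#   python huggingface.py <original args> --find-batch-sizes \
#       --only=BertForMaskedLM --output=huggingface_models_list.txt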


def huggingface_main():
    # Code to refresh model names and batch sizes
    # if "--find-batch-sizes" not in sys.argv:
    #     refresh_model_names_and_batch_sizes()
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(HuggingfaceRunner())


if __name__ == "__main__":
    huggingface_main()