#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main


import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
finally:
    from timm import __version__ as timmversion
    from timm.data import resolve_data_config
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        TIMM_MODELS[model_name] = int(batch_size)
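# Note: each line of timm_models_list.txt parsed above is expected to be a
# space-separated "<model_name> <batch_size>" pair, e.g.
# "beit_base_patch16_224 64" (batch size illustrative); anything else fails
# the split/int parse.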

# TODO - Figure out the reason for the cold start memory spike

BATCH_SIZE_DIVISORS = {
    "beit_base_patch16_224": 2,
    "convit_base": 2,
    "convmixer_768_32": 2,
    "convnext_base": 2,
    "cspdarknet53": 2,
    "deit_base_distilled_patch16_224": 2,
    "gluon_xception65": 2,
    "mobilevit_s": 2,
    "pnasnet5large": 2,
    "poolformer_m36": 2,
    "resnest101e": 2,
    "swin_base_patch4_window7_224": 2,
    "swsl_resnext101_32x16d": 2,
    "vit_base_patch16_224": 2,
    "volo_d1_224": 2,
    "jx_nest_base": 4,
}
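# Note: load_model() divides the batch size recorded in timm_models_list.txt
# by these divisors (flooring at 1), so an entry of 2 halves the benchmarked
# batch size for that model.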

REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "levit_128",
    "sebotnet33ts_256",
    "beit_base_patch16_224",
    "cspdarknet53",
}

# These models need higher tolerance in MaxAutotune mode
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "inception_v3",
    "mobilenetv3_large_100",
    "cspdarknet53",
}


def refresh_model_names():
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                while True:
                    line = f.readline()
                    if not line:
                        break
                    if not line.startswith("model = timm.create_model("):
                        continue

                    model = line.split("'")[1]
                    # print(model)
                    models.add(model)
        return models

    def get_family_name(name):
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        family = {}
        for model_name in models:
            family_name = get_family_name(model_name)
            if family_name not in family:
                family[family_name] = []
            family[family_name].append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    docs_models_family = populate_family(docs_models)

    for key in docs_models_family:
        del all_models_family[key]

    chosen_models = set()
    chosen_models.update(value[0] for value in docs_models_family.values())

    chosen_models.update(value[0] for key, value in all_models_family.items())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = "benchmarks/" + filename
    with open(filename, "w") as fw:
        for model_name in sorted(chosen_models):
            fw.write(model_name + "\n")
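# Note: refresh_model_names() above is a maintenance helper that is never
# called during a benchmark run; it is presumably invoked by hand (with a
# pytorch-image-models checkout next to this repo, per the TODO inside) to
# regenerate timm_models_list.txt.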


class TimmRunner(BenchmarkRunner):
    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        return set()

    @property
    def get_output_amp_train_process_func(self):
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        return {
            "lcnet_050",
        }

    @download_retry_decorator
    def _download_model(self, model_name):
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model
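    # Note: download_retry_decorator (from common.py) is assumed to retry
    # the call above when the pretrained-weight download fails transiently;
    # create_model(pretrained=True) both builds the module and fetches its
    # weights.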

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        data_config = resolve_data_config(
            vars(self._args) if timmversion >= "0.8.0" else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev
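        # Note: the example inputs above are synthetic images, i.e. uniform
        # integers in [0, 256) standardized to roughly zero mean and unit
        # variance under a fixed seed, rather than real dataset-normalized
        # images.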

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search("|".join(args.filter), model_name, re.IGNORECASE)
                or re.search("|".join(args.exclude), model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical differences, so we need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine
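    # Tolerance recap for the method above: 1e-3 baseline; 8e-2 when freezing
    # a model in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING; under training the
    # value is then overwritten to 8e-2, 4e-2, or 1e-2 depending on the sets
    # defined at the top of this file.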

    def _gen_target(self, batch_size, device):
        return torch.empty((batch_size,) + (), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values need to be scaled down even further for these models.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
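    # Note: forward_and_backward_pass clones its inputs so in-place mutation
    # by the model cannot leak across runs, and routes backward() through
    # self.grad_scaler (set up in common.py), which acts as a pass-through
    # when gradient scaling is disabled.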


def timm_main():
    logging.basicConfig(level=logging.WARNING)
    warnings.filterwarnings("ignore")
    main(TimmRunner())


if __name__ == "__main__":
    timm_main()
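# Example invocation (hypothetical; the full flag set is defined in
# common.py, only flags referenced in this file are shown):
#   python timm_models.py --training --accuracy --filter beit_base_patch16_224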