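# NOTE: the following lines are residue of the web page this file was copied
# from (repository header of the source viewer); they are not part of the script.
# pytorch
#
# Форк (Fork)
# 0
# /
# timm_models.py
# 422 строки · 11.9 Кб (422 lines · 11.9 KB)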
1
#!/usr/bin/env python3
2

3
import importlib
4
import logging
5
import os
6
import re
7
import subprocess
8
import sys
9
import warnings
10

11

12
try:
13
    from .common import BenchmarkRunner, download_retry_decorator, main
14
except ImportError:
15
    from common import BenchmarkRunner, download_retry_decorator, main
16

17
import torch
18
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
19
from torch._dynamo.utils import clone_inputs
20

21

22
# Enable FX graph caching
23
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
24
    torch._inductor.config.fx_graph_cache = True
25

26

27
def pip_install(package):
    """Install ``package`` with pip, using the interpreter running this script.

    Raises subprocess.CalledProcessError when pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.run(pip_command, check=True)
29

30

31
# Make sure timm is importable, installing it from source on first use.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # Path-entry finders cache directory contents; invalidate them so a
    # package installed after interpreter start-up can be found.
    importlib.invalidate_caches()

# Imported after the probe instead of inside a ``finally`` clause: ``finally``
# would also run after a failed ``pip install`` (or an unrelated import error)
# and raise a secondary ImportError on top of the real failure.
from timm import __version__ as timmversion
from timm.data import resolve_data_config
from timm.models import create_model
40

41
# Benchmarked model names mapped to their recorded batch sizes. They are read
# from timm_models_list.txt next to this script, where each non-blank line
# holds "<model_name> <batch_size>".
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    for line in fh:
        # split() with no argument tolerates runs of spaces/tabs and yields
        # an empty list for blank lines, which carry no entry.
        fields = line.split()
        if not fields:
            continue
        if len(fields) != 2:
            raise ValueError(
                f"{filename}: expected '<model_name> <batch_size>', got {line!r}"
            )
        model_name, batch_size = fields
        TIMM_MODELS[model_name] = int(batch_size)
50

51

52
# TODO - Figure out the reason of cold start memory spike

# Divisors applied to a model's recorded batch size in TimmRunner.load_model
# (the divided size is clamped to at least 1). Every model listed here is
# halved, except jx_nest_base, which is quartered.
BATCH_SIZE_DIVISORS = {
    **dict.fromkeys(
        (
            "beit_base_patch16_224",
            "convit_base",
            "convmixer_768_32",
            "convnext_base",
            "cspdarknet53",
            "deit_base_distilled_patch16_224",
            "gluon_xception65",
            "mobilevit_s",
            "pnasnet5large",
            "poolformer_m36",
            "resnest101e",
            "swin_base_patch4_window7_224",
            "swsl_resnext101_32x16d",
            "vit_base_patch16_224",
            "volo_d1_224",
        ),
        2,
    ),
    "jx_nest_base": 4,
}
72

73
# Training-mode accuracy tolerance of 4e-2 (TimmRunner.get_tolerance_and_cosine_flag).
# "sebotnet33ts_256" also appears in REQUIRE_EVEN_HIGHER_TOLERANCE, which is
# checked first, so the 8e-2 tolerance applies to it.
REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

# Training-mode accuracy tolerance of 8e-2.
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "beit_base_patch16_224",
    "cspdarknet53",
    "levit_128",
    "sebotnet33ts_256",
}

# These models need higher tolerance in MaxAutotune mode: 8e-2 in training mode
# whenever inductor's ``max_autotune`` config is enabled.
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {"gluon_inception_v3"}

# Tolerance of 8e-2 when ``args.freezing`` is set (conv-batchnorm fusion, see
# https://github.com/pytorch/pytorch/issues/120545). In training mode the
# training branch of get_tolerance_and_cosine_flag assigns its own tolerance.
REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

# Models whose loss goes through TimmRunner.scaled_compute_loss (divided by 1000).
SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

# Returned by TimmRunner.force_amp_for_fp16_bf16_models; how the set is applied
# is decided by BenchmarkRunner (not visible here).
FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

# Returned by TimmRunner.skip_accuracy_check_as_eager_non_deterministic, but
# only for accuracy runs in training mode.
SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {"xcit_large_24_p8_224"}

# Membership is reported by TimmRunner.use_larger_multiplier_for_smaller_tensor;
# interpreted by BenchmarkRunner.
REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "cspdarknet53",
    "inception_v3",
    "mobilenetv3_large_100",
}
126

127

128
def refresh_model_names():
    """Regenerate timm_models_list.txt with one representative model per family.

    Candidate models come from two sources:

    * the ``model = timm.create_model('<name>' ...)`` lines in the
      pytorch-image-models docs, searched at
      ``../pytorch-image-models/docs/models/*.md``;
    * timm's registry of pretrained models, excluding ``*in21k`` names.

    Models are grouped into families by ``get_family_name``. A family found in
    the docs takes its representative from the docs; every other registry
    family contributes its first listed model. The sorted names are written to
    ``benchmarks/timm_models_list.txt`` if a ``benchmarks`` directory exists in
    the current working directory, otherwise to ``timm_models_list.txt``.

    NOTE(review): only model names are written, while the module-level loader
    expects "<model_name> <batch_size>" on each line, so batch sizes have to be
    added before the regenerated file can be loaded.
    """
    import glob

    from timm.models import list_models

    # First quoted string on a line, accepting either quote style.
    quoted_name_re = re.compile(r"""['"]([^'"]+)['"]""")

    def read_models_from_docs():
        """Return the set of model names passed to timm.create_model in the docs."""
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                for line in f:
                    if not line.startswith("model = timm.create_model("):
                        continue
                    # Lines without a quoted name are skipped rather than
                    # raising IndexError.
                    match = quoted_name_re.search(line)
                    if match:
                        models.add(match.group(1))
        return models

    def get_family_name(name):
        """Map a model name to the family it is grouped under."""
        known_families = [
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        ]

        # First substring match wins, so list order matters
        # (e.g. "ecaresnet" must be tested before "resnet").
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group ``models`` by family, keeping their iteration order per family."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    # ``docs_models`` is a set; sort it so each family's representative (its
    # first entry) does not depend on set iteration order between runs.
    docs_models_family = populate_family(sorted(docs_models))

    # Families covered by the docs take their representative from the docs.
    # A docs family need not appear in the filtered registry listing, so it is
    # removed leniently instead of raising KeyError.
    for key in docs_models_family:
        all_models_family.pop(key, None)

    chosen_models = {members[0] for members in docs_models_family.values()}
    chosen_models.update(members[0] for members in all_models_family.values())

    filename = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        filename = os.path.join("benchmarks", filename)
    with open(filename, "w") as fw:
        fw.writelines(f"{model_name}\n" for model_name in sorted(chosen_models))
209

210

211
class TimmRunner(BenchmarkRunner):
    """Benchmark runner for the PyTorch Image Models (timm) suite.

    Model names and recorded batch sizes come from ``TIMM_MODELS``; the
    module-level tables adjust batch sizes, loss scaling and accuracy
    tolerances for individual models.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        """Models listed in FORCE_AMP_FOR_FP16_BF16_MODELS (consumed by BenchmarkRunner)."""
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        """Empty: no timm model is listed for this override."""
        return set()

    @property
    def get_output_amp_train_process_func(self):
        """Empty mapping: no per-model entries for this suite."""
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        """Models whose accuracy check is skipped; non-empty only for accuracy runs in training mode."""
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        """Models reported to BenchmarkRunner for nn.Module guarding."""
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        """Models reported to BenchmarkRunner for inlining built-in nn.Modules."""
        return {
            "lcnet_050",
        }

    @staticmethod
    def _version_tuple(version):
        """Return the first three numeric components of a version string.

        ``"0.9.16"`` -> ``(0, 9, 16)``; ``"1.0.0.dev0"`` -> ``(1, 0, 0)``.
        Tuples compare numerically, unlike version strings
        (``"0.10.0" >= "0.8.0"`` is False as a string comparison).
        """
        return tuple(int(part) for part in re.findall(r"\d+", version)[:3])

    @download_retry_decorator
    def _download_model(self, model_name):
        """Create ``model_name`` with pretrained weights, 3 input channels and drop_rate=0.0."""
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        """Build ``model_name`` on ``device`` together with a synthetic input batch.

        ``batch_size`` defaults to the recorded size from TIMM_MODELS, reduced
        by BATCH_SIZE_DIVISORS where listed. Also sets ``self.num_classes``,
        ``self.target``, ``self.loss`` and selects the loss function.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)``, where
            ``example_inputs`` is a one-element list holding the input tensor.

        Raises:
            NotImplementedError: activation checkpointing was requested.
            RuntimeError: the model could not be created.
        """
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        # timm >= 0.8.0 is handed a dict of the arguments, older releases the
        # namespace itself. Versions are compared numerically, not as strings.
        timm_args = (
            vars(self._args)
            if self._version_tuple(timmversion) >= (0, 8, 0)
            else self._args
        )
        data_config = resolve_data_config(
            timm_args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1
            )
        batch_size = batch_size or recorded_batch_size

        # Reproducible synthetic batch: integer values in [0, 256) converted to
        # float32 and standardized with the batch's own mean and std.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss
        else:
            # Drop an override installed while loading an earlier model on this
            # runner, so the class-level compute_loss applies again.
            vars(self).pop("compute_loss", None)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield model names in the benchmark index range that pass the CLI filters.

        A name is yielded when it matches one of ``args.filter``, matches none
        of ``args.exclude`` (both case-insensitive regexes), and is neither in
        ``args.exclude_exact`` nor in ``self.skip_models``.
        """
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        # Build the include/exclude alternations once instead of once per model.
        include_re = re.compile("|".join(args.filter), re.IGNORECASE)
        exclude_re = re.compile("|".join(args.exclude), re.IGNORECASE)
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not include_re.search(model_name)
                or exclude_re.search(model_name)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        """Return ``torch.enable_grad()`` when training, else ``torch.no_grad()``."""
        if is_training:
            return torch.enable_grad()
        else:
            return torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        """True for models in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR."""
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of ``name``.

        Inference defaults to 1e-3 (8e-2 for freezing-sensitive models). In
        training mode the tolerance is always reassigned: 8e-2, 4e-2 or 1e-2
        depending on the model's tolerance tables.
        """
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # the conv-batchnorm fusion used under freezing may cause relatively
            # large numerical difference. We need a larger tolerance.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        """Random int64 class indices in [0, self.num_classes), one per sample."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upsets accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        # Loss values need zoom out further.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        """Run ``mod(*inputs)`` under the configured autocast context."""
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step on cloned inputs.

        Takes the first element when the model returns a tuple. Returns the
        collected results when ``collect_outputs`` is true, else None.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None
413

414

415
def timm_main():
    """Entry point: run the timm suite through the shared benchmark ``main``.

    Python warnings are silenced and logging is configured at WARNING level
    before the runner is created.
    """
    warnings.filterwarnings("ignore")
    logging.basicConfig(level=logging.WARNING)
    runner = TimmRunner()
    main(runner)


if __name__ == "__main__":
    timm_main()
423

# NOTE: web-page residue (the GitVerse cookie notice), not part of the script:
# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.