#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial

import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

has_compile = hasattr(torch, 'compile')


_logger = logging.getLogger('train')
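
# Example invocation (illustrative only; the dataset path and hyperparameter values below are
# placeholders rather than recommendations from this script):
#
#   python train.py --data-dir /path/to/imagenet --model resnet50 -b 128 --amp --sched cosine --epochs 300
#
# Multi-GPU training is typically launched with a distributed launcher (e.g. torchrun), one process per device.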
74

75
# The first arg parser parses out only the --config argument, this argument is used to
76
# load a yaml file containing key-values that override the defaults for the main parser below
77
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
78
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
79
                    help='YAML config file specifying default arguments')
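# A minimal example of such a YAML config (hypothetical values). Keys are the argparse dest names
# (e.g. --batch-size becomes batch_size) and are applied via parser.set_defaults(**cfg):
#
#   model: resnet50
#   batch_size: 256
#   epochs: 100
#   amp: true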
80

81

82
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
83

84
# Dataset parameters
85
group = parser.add_argument_group('Dataset parameters')
86
# Keep this argument outside the dataset group because it is positional.
87
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
88
                    help='path to dataset (positional is *deprecated*, use --data-dir)')
89
parser.add_argument('--data-dir', metavar='DIR',
90
                    help='path to dataset (root dir)')
91
parser.add_argument('--dataset', metavar='NAME', default='',
92
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
93
group.add_argument('--train-split', metavar='NAME', default='train',
94
                   help='dataset train split (default: train)')
95
group.add_argument('--val-split', metavar='NAME', default='validation',
96
                   help='dataset validation split (default: validation)')
97
parser.add_argument('--train-num-samples', default=None, type=int,
98
                    metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
99
parser.add_argument('--val-num-samples', default=None, type=int,
100
                    metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
101
group.add_argument('--dataset-download', action='store_true', default=False,
102
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
103
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
104
                   help='path to class to idx mapping file (default: "")')
105
group.add_argument('--input-img-mode', default=None, type=str,
106
                   help='Dataset image conversion mode for input images.')
107
group.add_argument('--input-key', default=None, type=str,
108
                   help='Dataset key for input images.')
109
group.add_argument('--target-key', default=None, type=str,
110
                   help='Dataset key for target labels.')
111

112
# Model parameters
113
group = parser.add_argument_group('Model parameters')
114
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
115
                   help='Name of model to train (default: "resnet50")')
116
group.add_argument('--pretrained', action='store_true', default=False,
117
                   help='Start with pretrained version of specified network (if avail)')
118
group.add_argument('--pretrained-path', default=None, type=str,
119
                   help='Load this checkpoint as if it were the pretrained weights (with adaptation).')
120
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
121
                   help='Load this checkpoint into model after initialization (default: none)')
122
group.add_argument('--resume', default='', type=str, metavar='PATH',
123
                   help='Resume full model and optimizer state from checkpoint (default: none)')
124
group.add_argument('--no-resume-opt', action='store_true', default=False,
125
                   help='prevent resume of optimizer state when resuming model')
126
group.add_argument('--num-classes', type=int, default=None, metavar='N',
127
                   help='number of label classes (Model default if None)')
128
group.add_argument('--gp', default=None, type=str, metavar='POOL',
129
                   help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
130
group.add_argument('--img-size', type=int, default=None, metavar='N',
131
                   help='Image size (default: None => model default)')
132
group.add_argument('--in-chans', type=int, default=None, metavar='N',
133
                   help='Image input channels (default: None => 3)')
134
group.add_argument('--input-size', default=None, nargs=3, type=int,
135
                   metavar='N N N',
136
                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
137
group.add_argument('--crop-pct', default=None, type=float,
138
                   metavar='N', help='Input image center crop percent (for validation only)')
139
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
140
                   help='Override mean pixel value of dataset')
141
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
142
                   help='Override std deviation of dataset')
143
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
144
                   help='Image resize interpolation type (overrides model)')
145
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
146
                   help='Input batch size for training (default: 128)')
147
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
148
                   help='Validation batch size override (default: None)')
149
group.add_argument('--channels-last', action='store_true', default=False,
150
                   help='Use channels_last memory layout')
151
group.add_argument('--fuser', default='', type=str,
152
                   help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
153
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
154
                   help='The number of steps to accumulate gradients (default: 1)')
155
group.add_argument('--grad-checkpointing', action='store_true', default=False,
156
                   help='Enable gradient checkpointing through model blocks/stages')
157
group.add_argument('--fast-norm', default=False, action='store_true',
158
                   help='enable experimental fast-norm')
159
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
160
group.add_argument('--head-init-scale', default=None, type=float,
161
                   help='Head initialization scale')
162
group.add_argument('--head-init-bias', default=None, type=float,
163
                   help='Head initialization bias value')
164

165
# scripting / codegen
166
scripting_group = group.add_mutually_exclusive_group()
167
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
168
                             help='torch.jit.script the full model')
169
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
170
                             help="Enable compilation w/ specified backend (default: inductor).")
171

172
# Device & distributed
173
group = parser.add_argument_group('Device parameters')
174
group.add_argument('--device', default='cuda', type=str,
175
                    help="Device (accelerator) to use.")
176
group.add_argument('--amp', action='store_true', default=False,
177
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
178
group.add_argument('--amp-dtype', default='float16', type=str,
179
                   help='lower precision AMP dtype (default: float16)')
180
group.add_argument('--amp-impl', default='native', type=str,
181
                   help='AMP impl to use, "native" or "apex" (default: native)')
182
group.add_argument('--no-ddp-bb', action='store_true', default=False,
183
                   help='Disable broadcast buffers for native DDP.')
184
group.add_argument('--synchronize-step', action='store_true', default=False,
185
                   help='Call torch.cuda.synchronize() at the end of each step.')
186
group.add_argument("--local_rank", default=0, type=int)
187
parser.add_argument('--device-modules', default=None, type=str, nargs='+',
188
                    help="Python imports for device backend modules.")
189

190
# Optimizer parameters
191
group = parser.add_argument_group('Optimizer parameters')
192
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
193
                   help='Optimizer (default: "sgd")')
194
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
195
                   help='Optimizer Epsilon (default: None, use opt default)')
196
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
197
                   help='Optimizer Betas (default: None, use opt default)')
198
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
199
                   help='Optimizer momentum (default: 0.9)')
200
group.add_argument('--weight-decay', type=float, default=2e-5,
201
                   help='weight decay (default: 2e-5)')
202
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
203
                   help='Clip gradient norm (default: None, no clipping)')
204
group.add_argument('--clip-mode', type=str, default='norm',
205
                   help='Gradient clipping mode. One of ("norm", "value", "agc")')
206
group.add_argument('--layer-decay', type=float, default=None,
207
                   help='layer-wise learning rate decay (default: None)')
208
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
209

210
# Learning rate schedule parameters
211
group = parser.add_argument_group('Learning rate schedule parameters')
212
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
213
                   help='LR scheduler (default: "cosine")')
214
group.add_argument('--sched-on-updates', action='store_true', default=False,
215
                   help='Apply LR scheduler step on update instead of epoch end.')
216
group.add_argument('--lr', type=float, default=None, metavar='LR',
217
                   help='learning rate, overrides lr-base if set (default: None)')
218
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
219
                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
220
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
221
                   help='base learning rate batch size (divisor, default: 256).')
222
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
223
                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
224
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
225
                   help='learning rate noise on/off epoch percentages')
226
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
227
                   help='learning rate noise limit percent (default: 0.67)')
228
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
229
                   help='learning rate noise std-dev (default: 1.0)')
230
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
231
                   help='learning rate cycle len multiplier (default: 1.0)')
232
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
233
                   help='amount to decay each learning rate cycle (default: 0.5)')
234
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
235
                   help='learning rate cycle limit, cycles enabled if > 1')
236
group.add_argument('--lr-k-decay', type=float, default=1.0,
237
                   help='learning rate k-decay for cosine/poly (default: 1.0)')
238
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
239
                   help='warmup learning rate (default: 1e-5)')
240
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
241
                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
242
group.add_argument('--epochs', type=int, default=300, metavar='N',
243
                   help='number of epochs to train (default: 300)')
244
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
245
                   help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
246
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
247
                   help='manual epoch number (useful on restarts)')
248
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
249
                   help='list of decay epoch indices for multistep lr. must be increasing')
250
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
251
                   help='epoch interval to decay LR')
252
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
253
                   help='epochs to warmup LR, if scheduler supports')
254
group.add_argument('--warmup-prefix', action='store_true', default=False,
255
                   help='Exclude warmup period from decay schedule.')
256
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
257
                   help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
258
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
259
                   help='patience epochs for Plateau LR scheduler (default: 10)')
260
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
261
                   help='LR decay rate (default: 0.1)')
262

263
# Augmentation & regularization parameters
264
group = parser.add_argument_group('Augmentation and regularization parameters')
265
group.add_argument('--no-aug', action='store_true', default=False,
266
                   help='Disable all training augmentation, override other train aug args')
267
group.add_argument('--train-crop-mode', type=str, default=None,
268
                   help='Crop mode for training images.')
269
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
270
                   help='Random resize scale (default: 0.08 1.0)')
271
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
272
                   help='Random resize aspect ratio (default: 0.75 1.33)')
273
group.add_argument('--hflip', type=float, default=0.5,
274
                   help='Horizontal flip training aug probability')
275
group.add_argument('--vflip', type=float, default=0.,
276
                   help='Vertical flip training aug probability')
277
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
278
                   help='Color jitter factor (default: 0.4)')
279
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
280
                   help='Probability of applying any color jitter.')
281
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
282
                   help='Probability of applying random grayscale conversion.')
283
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
284
                   help='Probability of applying gaussian blur.')
285
group.add_argument('--aa', type=str, default=None, metavar='NAME',
286
                   help='Use AutoAugment policy. "v0" or "original". (default: None)')
287
group.add_argument('--aug-repeats', type=float, default=0,
288
                   help='Number of augmentation repetitions (distributed training only) (default: 0)')
289
group.add_argument('--aug-splits', type=int, default=0,
290
                   help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
291
group.add_argument('--jsd-loss', action='store_true', default=False,
292
                   help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
293
group.add_argument('--bce-loss', action='store_true', default=False,
294
                   help='Enable BCE loss w/ Mixup/CutMix use.')
295
group.add_argument('--bce-sum', action='store_true', default=False,
296
                   help='Sum over classes when using BCE loss.')
297
group.add_argument('--bce-target-thresh', type=float, default=None,
298
                   help='Threshold for binarizing softened BCE targets (default: None, disabled).')
299
group.add_argument('--bce-pos-weight', type=float, default=None,
300
                   help='Positive weighting for BCE loss.')
301
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
302
                   help='Random erase prob (default: 0.)')
303
group.add_argument('--remode', type=str, default='pixel',
304
                   help='Random erase mode (default: "pixel")')
305
group.add_argument('--recount', type=int, default=1,
306
                   help='Random erase count (default: 1)')
307
group.add_argument('--resplit', action='store_true', default=False,
308
                   help='Do not random erase first (clean) augmentation split')
309
group.add_argument('--mixup', type=float, default=0.0,
310
                   help='mixup alpha, mixup enabled if > 0. (default: 0.)')
311
group.add_argument('--cutmix', type=float, default=0.0,
312
                   help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
313
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
314
                   help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
315
group.add_argument('--mixup-prob', type=float, default=1.0,
316
                   help='Probability of performing mixup or cutmix when either/both is enabled')
317
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
318
                   help='Probability of switching to cutmix when both mixup and cutmix enabled')
319
group.add_argument('--mixup-mode', type=str, default='batch',
320
                   help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
321
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
322
                   help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
323
group.add_argument('--smoothing', type=float, default=0.1,
324
                   help='Label smoothing (default: 0.1)')
325
group.add_argument('--train-interpolation', type=str, default='random',
326
                   help='Training interpolation (random, bilinear, bicubic default: "random")')
327
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
328
                   help='Dropout rate (default: 0.)')
329
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
330
                   help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
331
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
332
                   help='Drop path rate (default: None)')
333
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
334
                   help='Drop block rate (default: None)')
335

336
# Batch norm parameters (only works with gen_efficientnet based models currently)
337
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
338
group.add_argument('--bn-momentum', type=float, default=None,
339
                   help='BatchNorm momentum override (if not None)')
340
group.add_argument('--bn-eps', type=float, default=None,
341
                   help='BatchNorm epsilon override (if not None)')
342
group.add_argument('--sync-bn', action='store_true',
343
                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
344
group.add_argument('--dist-bn', type=str, default='reduce',
345
                   help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
346
group.add_argument('--split-bn', action='store_true',
347
                   help='Enable separate BN layers per augmentation split.')
348

349
# Model Exponential Moving Average
350
group = parser.add_argument_group('Model exponential moving average parameters')
351
group.add_argument('--model-ema', action='store_true', default=False,
352
                   help='Enable tracking moving average of model weights.')
353
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
354
                   help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
355
group.add_argument('--model-ema-decay', type=float, default=0.9998,
356
                   help='Decay factor for model weights moving average (default: 0.9998)')
357
group.add_argument('--model-ema-warmup', action='store_true',
358
                   help='Enable warmup for model EMA decay.')
359

360
# Misc
361
group = parser.add_argument_group('Miscellaneous parameters')
362
group.add_argument('--seed', type=int, default=42, metavar='S',
363
                   help='random seed (default: 42)')
364
group.add_argument('--worker-seeding', type=str, default='all',
365
                   help='worker seed mode (default: all)')
366
group.add_argument('--log-interval', type=int, default=50, metavar='N',
367
                   help='how many batches to wait before logging training status')
368
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
369
                   help='how many batches to wait before writing recovery checkpoint')
370
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
371
                   help='number of checkpoints to keep (default: 10)')
372
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
373
                   help='how many training processes to use (default: 4)')
374
group.add_argument('--save-images', action='store_true', default=False,
375
                   help='save images of input batches every log interval for debugging')
376
group.add_argument('--pin-mem', action='store_true', default=False,
377
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
378
group.add_argument('--no-prefetcher', action='store_true', default=False,
379
                   help='disable fast prefetcher')
380
group.add_argument('--output', default='', type=str, metavar='PATH',
381
                   help='path to output folder (default: none, current dir)')
382
group.add_argument('--experiment', default='', type=str, metavar='NAME',
383
                   help='name of train experiment, name of sub-folder for output')
384
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
385
                   help='Best metric (default: "top1")')
386
group.add_argument('--tta', type=int, default=0, metavar='N',
387
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
388
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
389
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
390
group.add_argument('--log-wandb', action='store_true', default=False,
391
                   help='log training and validation metrics to wandb')
392

393

394
def _parse_args():
395
    # Do we have a config file to parse?
396
    args_config, remaining = config_parser.parse_known_args()
397
    if args_config.config:
398
        with open(args_config.config, 'r') as f:
399
            cfg = yaml.safe_load(f)
400
            parser.set_defaults(**cfg)
401

402
    # The main arg parser parses the rest of the args, the usual
403
    # defaults will have been overridden if config file specified.
404
    args = parser.parse_args(remaining)
405

406
    # Cache the args as a text string to save them in the output dir later
407
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
408
    return args, args_text
409

410

411
def main():
412
    utils.setup_default_logging()
413
    args, args_text = _parse_args()
414

415
    if args.device_modules:
416
        for module in args.device_modules:
417
            importlib.import_module(module)
418

419
    if torch.cuda.is_available():
420
        torch.backends.cuda.matmul.allow_tf32 = True
421
        torch.backends.cudnn.benchmark = True
422

423
    args.prefetcher = not args.no_prefetcher
424
    args.grad_accum_steps = max(1, args.grad_accum_steps)
425
    device = utils.init_distributed_device(args)
426
    if args.distributed:
427
        _logger.info(
428
            'Training in distributed mode with multiple processes, 1 device per process. '
429
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
430
    else:
431
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
432
    assert args.rank >= 0
433

434
    # resolve AMP arguments based on PyTorch / Apex availability
435
    use_amp = None
436
    amp_dtype = torch.float16
437
    if args.amp:
438
        if args.amp_impl == 'apex':
439
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
440
            use_amp = 'apex'
441
            assert args.amp_dtype == 'float16'
442
        else:
443
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
444
            use_amp = 'native'
445
            assert args.amp_dtype in ('float16', 'bfloat16')
446
        if args.amp_dtype == 'bfloat16':
447
            amp_dtype = torch.bfloat16
448

449
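    # Seed is offset by process rank so each distributed worker draws a different augmentation stream.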
    utils.random_seed(args.seed, args.rank)
450

451
    if args.fuser:
452
        utils.set_jit_fuser(args.fuser)
453
    if args.fast_norm:
454
        set_fast_norm()
455

456
    in_chans = 3
457
    if args.in_chans is not None:
458
        in_chans = args.in_chans
459
    elif args.input_size is not None:
460
        in_chans = args.input_size[0]
461

462
    factory_kwargs = {}
463
    if args.pretrained_path:
464
        # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
465
        factory_kwargs['pretrained_cfg_overlay'] = dict(
466
            file=args.pretrained_path,
467
            num_classes=-1,  # force head adaptation
468
        )
469

470
    model = create_model(
471
        args.model,
472
        pretrained=args.pretrained,
473
        in_chans=in_chans,
474
        num_classes=args.num_classes,
475
        drop_rate=args.drop,
476
        drop_path_rate=args.drop_path,
477
        drop_block_rate=args.drop_block,
478
        global_pool=args.gp,
479
        bn_momentum=args.bn_momentum,
480
        bn_eps=args.bn_eps,
481
        scriptable=args.torchscript,
482
        checkpoint_path=args.initial_checkpoint,
483
        **factory_kwargs,
484
        **args.model_kwargs,
485
    )
486
    if args.head_init_scale is not None:
487
        with torch.no_grad():
488
            model.get_classifier().weight.mul_(args.head_init_scale)
489
            model.get_classifier().bias.mul_(args.head_init_scale)
490
    if args.head_init_bias is not None:
491
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)
492

493
    if args.num_classes is None:
494
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
495
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
496

497
    if args.grad_checkpointing:
498
        model.set_grad_checkpointing(enable=True)
499

500
    if utils.is_primary(args):
501
        _logger.info(
502
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
503

504
    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
505

506
    # setup augmentation batch splits for contrastive loss or split bn
507
    num_aug_splits = 0
508
    if args.aug_splits > 0:
509
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
510
        num_aug_splits = args.aug_splits
511

512
    # enable split bn (separate bn stats per batch-portion)
513
    if args.split_bn:
514
        assert num_aug_splits > 1 or args.resplit
515
        model = convert_splitbn_model(model, max(num_aug_splits, 2))
516

517
    # move model to GPU, enable channels last layout if set
518
    model.to(device=device)
519
    if args.channels_last:
520
        model.to(memory_format=torch.channels_last)
521

522
    # setup synchronized BatchNorm for distributed training
523
    if args.distributed and args.sync_bn:
524
        args.dist_bn = ''  # disable dist_bn when sync BN active
525
        assert not args.split_bn
526
        if has_apex and use_amp == 'apex':
527
            # Apex SyncBN used with Apex AMP
528
            # WARNING this won't currently work with models using BatchNormAct2d
529
            model = convert_syncbn_model(model)
530
        else:
531
            model = convert_sync_batchnorm(model)
532
        if utils.is_primary(args):
533
            _logger.info(
534
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
535
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
536

537
    if args.torchscript:
538
        assert not args.torchcompile
539
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
540
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
541
        model = torch.jit.script(model)
542

543
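    # If --lr is unset, derive it from --lr-base: scale by global_batch_size / lr_base_size,
    # using the square root of that ratio for adaptive optimizers ('ada' or 'lamb' in the opt name), else linear.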
    if not args.lr:
544
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
545
        batch_ratio = global_batch_size / args.lr_base_size
546
        if not args.lr_base_scale:
547
            on = args.opt.lower()
548
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
549
        if args.lr_base_scale == 'sqrt':
550
            batch_ratio = batch_ratio ** 0.5
551
        args.lr = args.lr_base * batch_ratio
552
        if utils.is_primary(args):
553
            _logger.info(
554
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
555
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
556

557
    optimizer = create_optimizer_v2(
558
        model,
559
        **optimizer_kwargs(cfg=args),
560
        **args.opt_kwargs,
561
    )
562

563
    # setup automatic mixed-precision (AMP) loss scaling and op casting
564
    amp_autocast = suppress  # do nothing
565
    loss_scaler = None
566
    if use_amp == 'apex':
567
        assert device.type == 'cuda'
568
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
569
        loss_scaler = ApexScaler()
570
        if utils.is_primary(args):
571
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
572
    elif use_amp == 'native':
573
        try:
574
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
575
        except (AttributeError, TypeError):
576
            # fallback to CUDA only AMP for PyTorch < 1.10
577
            assert device.type == 'cuda'
578
            amp_autocast = torch.cuda.amp.autocast
579
        if device.type == 'cuda' and amp_dtype == torch.float16:
580
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
581
            loss_scaler = NativeScaler()
582
        if utils.is_primary(args):
583
            _logger.info('Using native Torch AMP. Training in mixed precision.')
584
    else:
585
        if utils.is_primary(args):
586
            _logger.info('AMP not enabled. Training in float32.')
587

588
    # optionally resume from a checkpoint
589
    resume_epoch = None
590
    if args.resume:
591
        resume_epoch = resume_checkpoint(
592
            model,
593
            args.resume,
594
            optimizer=None if args.no_resume_opt else optimizer,
595
            loss_scaler=None if args.no_resume_opt else loss_scaler,
596
            log_info=utils.is_primary(args),
597
        )
598

599
    # setup exponential moving average of model weights, SWA could be used here too
600
    model_ema = None
601
    if args.model_ema:
602
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
603
        model_ema = utils.ModelEmaV3(
604
            model,
605
            decay=args.model_ema_decay,
606
            use_warmup=args.model_ema_warmup,
607
            device='cpu' if args.model_ema_force_cpu else None,
608
        )
609
        if args.resume:
610
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
611
        if args.torchcompile:
612
            model_ema = torch.compile(model_ema, backend=args.torchcompile)
613

614
    # setup distributed training
615
    if args.distributed:
616
        if has_apex and use_amp == 'apex':
617
            # Apex DDP preferred unless native amp is activated
618
            if utils.is_primary(args):
619
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
620
            model = ApexDDP(model, delay_allreduce=True)
621
        else:
622
            if utils.is_primary(args):
623
                _logger.info("Using native Torch DistributedDataParallel.")
624
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
625
        # NOTE: EMA model does not need to be wrapped by DDP
626

627
    if args.torchcompile:
628
        # torch compile should be done after DDP
629
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
630
        model = torch.compile(model, backend=args.torchcompile)
631

632
    # create the train and eval datasets
633
    if args.data and not args.data_dir:
634
        args.data_dir = args.data
635
    if args.input_img_mode is None:
636
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
637
    else:
638
        input_img_mode = args.input_img_mode
639

640
    dataset_train = create_dataset(
641
        args.dataset,
642
        root=args.data_dir,
643
        split=args.train_split,
644
        is_training=True,
645
        class_map=args.class_map,
646
        download=args.dataset_download,
647
        batch_size=args.batch_size,
648
        seed=args.seed,
649
        repeats=args.epoch_repeats,
650
        input_img_mode=input_img_mode,
651
        input_key=args.input_key,
652
        target_key=args.target_key,
653
        num_samples=args.train_num_samples,
654
    )
655

656
    if args.val_split:
657
        dataset_eval = create_dataset(
658
            args.dataset,
659
            root=args.data_dir,
660
            split=args.val_split,
661
            is_training=False,
662
            class_map=args.class_map,
663
            download=args.dataset_download,
664
            batch_size=args.batch_size,
665
            input_img_mode=input_img_mode,
666
            input_key=args.input_key,
667
            target_key=args.target_key,
668
            num_samples=args.val_num_samples,
669
        )
670

671
    # setup mixup / cutmix
672
    collate_fn = None
673
    mixup_fn = None
674
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
675
    if mixup_active:
676
        mixup_args = dict(
677
            mixup_alpha=args.mixup,
678
            cutmix_alpha=args.cutmix,
679
            cutmix_minmax=args.cutmix_minmax,
680
            prob=args.mixup_prob,
681
            switch_prob=args.mixup_switch_prob,
682
            mode=args.mixup_mode,
683
            label_smoothing=args.smoothing,
684
            num_classes=args.num_classes
685
        )
686
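        # With the prefetcher, mixup/cutmix is folded into the collate function (FastCollateMixup);
        # otherwise Mixup is applied to each batch inside the training loop.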
        if args.prefetcher:
687
            assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
688
            collate_fn = FastCollateMixup(**mixup_args)
689
        else:
690
            mixup_fn = Mixup(**mixup_args)
691

692
    # wrap dataset in AugMix helper
693
    if num_aug_splits > 1:
694
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
695

696
    # create data loaders w/ augmentation pipeline
697
    train_interpolation = args.train_interpolation
698
    if args.no_aug or not train_interpolation:
699
        train_interpolation = data_config['interpolation']
700
    loader_train = create_loader(
701
        dataset_train,
702
        input_size=data_config['input_size'],
703
        batch_size=args.batch_size,
704
        is_training=True,
705
        no_aug=args.no_aug,
706
        re_prob=args.reprob,
707
        re_mode=args.remode,
708
        re_count=args.recount,
709
        re_split=args.resplit,
710
        train_crop_mode=args.train_crop_mode,
711
        scale=args.scale,
712
        ratio=args.ratio,
713
        hflip=args.hflip,
714
        vflip=args.vflip,
715
        color_jitter=args.color_jitter,
716
        color_jitter_prob=args.color_jitter_prob,
717
        grayscale_prob=args.grayscale_prob,
718
        gaussian_blur_prob=args.gaussian_blur_prob,
719
        auto_augment=args.aa,
720
        num_aug_repeats=args.aug_repeats,
721
        num_aug_splits=num_aug_splits,
722
        interpolation=train_interpolation,
723
        mean=data_config['mean'],
724
        std=data_config['std'],
725
        num_workers=args.workers,
726
        distributed=args.distributed,
727
        collate_fn=collate_fn,
728
        pin_memory=args.pin_mem,
729
        device=device,
730
        use_prefetcher=args.prefetcher,
731
        use_multi_epochs_loader=args.use_multi_epochs_loader,
732
        worker_seeding=args.worker_seeding,
733
    )
734

735
    loader_eval = None
736
    if args.val_split:
737
        eval_workers = args.workers
738
        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
739
            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
740
            eval_workers = min(2, args.workers)
741
        loader_eval = create_loader(
742
            dataset_eval,
743
            input_size=data_config['input_size'],
744
            batch_size=args.validation_batch_size or args.batch_size,
745
            is_training=False,
746
            interpolation=data_config['interpolation'],
747
            mean=data_config['mean'],
748
            std=data_config['std'],
749
            num_workers=eval_workers,
750
            distributed=args.distributed,
751
            crop_pct=data_config['crop_pct'],
752
            pin_memory=args.pin_mem,
753
            device=device,
754
            use_prefetcher=args.prefetcher,
755
        )
756

757
    # setup loss function
758
    if args.jsd_loss:
759
        assert num_aug_splits > 1  # JSD only valid with aug splits set
760
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
761
    elif mixup_active:
762
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
763
        if args.bce_loss:
764
            train_loss_fn = BinaryCrossEntropy(
765
                target_threshold=args.bce_target_thresh,
766
                sum_classes=args.bce_sum,
767
                pos_weight=args.bce_pos_weight,
768
            )
769
        else:
770
            train_loss_fn = SoftTargetCrossEntropy()
771
    elif args.smoothing:
772
        if args.bce_loss:
773
            train_loss_fn = BinaryCrossEntropy(
774
                smoothing=args.smoothing,
775
                target_threshold=args.bce_target_thresh,
776
                sum_classes=args.bce_sum,
777
                pos_weight=args.bce_pos_weight,
778
            )
779
        else:
780
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
781
    else:
782
        train_loss_fn = nn.CrossEntropyLoss()
783
    train_loss_fn = train_loss_fn.to(device=device)
784
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
785

786
    # setup checkpoint saver and eval metric tracking
787
    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
788
    decreasing_metric = eval_metric == 'loss'
789
    best_metric = None
790
    best_epoch = None
791
    saver = None
792
    output_dir = None
793
    if utils.is_primary(args):
794
        if args.experiment:
795
            exp_name = args.experiment
796
        else:
797
            exp_name = '-'.join([
798
                datetime.now().strftime("%Y%m%d-%H%M%S"),
799
                safe_model_name(args.model),
800
                str(data_config['input_size'][-1])
801
            ])
802
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
803
        saver = utils.CheckpointSaver(
804
            model=model,
805
            optimizer=optimizer,
806
            args=args,
807
            model_ema=model_ema,
808
            amp_scaler=loss_scaler,
809
            checkpoint_dir=output_dir,
810
            recovery_dir=output_dir,
811
            decreasing=decreasing_metric,
812
            max_history=args.checkpoint_hist
813
        )
814
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
815
            f.write(args_text)
816

817
    if utils.is_primary(args) and args.log_wandb:
818
        if has_wandb:
819
            wandb.init(project=args.experiment, config=args)
820
        else:
821
            _logger.warning(
822
                "You've requested to log metrics to wandb but package not found. "
823
                "Metrics not being logged to wandb, try `pip install wandb`")
824

825
    # setup learning rate schedule and starting epoch
826
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
827
    lr_scheduler, num_epochs = create_scheduler_v2(
828
        optimizer,
829
        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
830
        updates_per_epoch=updates_per_epoch,
831
    )
832
    start_epoch = 0
833
    if args.start_epoch is not None:
834
        # a specified start_epoch will always override the resume epoch
835
        start_epoch = args.start_epoch
836
    elif resume_epoch is not None:
837
        start_epoch = resume_epoch
838
    if lr_scheduler is not None and start_epoch > 0:
839
        if args.sched_on_updates:
840
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
841
        else:
842
            lr_scheduler.step(start_epoch)
843

844
    if utils.is_primary(args):
845
        _logger.info(
846
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
847

848
    results = []
849
    try:
850
        for epoch in range(start_epoch, num_epochs):
851
            if hasattr(dataset_train, 'set_epoch'):
852
                dataset_train.set_epoch(epoch)
853
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
854
                loader_train.sampler.set_epoch(epoch)
855

856
            train_metrics = train_one_epoch(
857
                epoch,
858
                model,
859
                loader_train,
860
                optimizer,
861
                train_loss_fn,
862
                args,
863
                lr_scheduler=lr_scheduler,
864
                saver=saver,
865
                output_dir=output_dir,
866
                amp_autocast=amp_autocast,
867
                loss_scaler=loss_scaler,
868
                model_ema=model_ema,
869
                mixup_fn=mixup_fn,
870
                num_updates_total=num_epochs * updates_per_epoch,
871
            )
872

873
            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
874
                if utils.is_primary(args):
875
                    _logger.info("Distributing BatchNorm running means and vars")
876
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
877

878
            if loader_eval is not None:
879
                eval_metrics = validate(
880
                    model,
881
                    loader_eval,
882
                    validate_loss_fn,
883
                    args,
884
                    device=device,
885
                    amp_autocast=amp_autocast,
886
                )
887

888
                if model_ema is not None and not args.model_ema_force_cpu:
889
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
890
                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
891

892
                    ema_eval_metrics = validate(
893
                        model_ema,
894
                        loader_eval,
895
                        validate_loss_fn,
896
                        args,
897
                        device=device,
898
                        amp_autocast=amp_autocast,
899
                        log_suffix=' (EMA)',
900
                    )
901
                    eval_metrics = ema_eval_metrics
902
            else:
903
                eval_metrics = None
904

905
            if output_dir is not None:
906
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
907
                utils.update_summary(
908
                    epoch,
909
                    train_metrics,
910
                    eval_metrics,
911
                    filename=os.path.join(output_dir, 'summary.csv'),
912
                    lr=sum(lrs) / len(lrs),
913
                    write_header=best_metric is None,
914
                    log_wandb=args.log_wandb and has_wandb,
915
                )
916

917
            if eval_metrics is not None:
918
                latest_metric = eval_metrics[eval_metric]
919
            else:
920
                latest_metric = train_metrics[eval_metric]
921

922
            if saver is not None:
923
                # save proper checkpoint with eval metric
924
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)
925

926
            if lr_scheduler is not None:
927
                # step LR for next epoch
928
                lr_scheduler.step(epoch + 1, latest_metric)
929

930
            results.append({
931
                'epoch': epoch,
932
                'train': train_metrics,
933
                'validation': eval_metrics,
934
            })
935

936
    except KeyboardInterrupt:
937
        pass
938

939
    results = {'all': results}
940
    if best_metric is not None:
941
        results['best'] = results['all'][best_epoch - start_epoch]
942
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
943
    print(f'--result\n{json.dumps(results, indent=4)}')
944

945

946
def train_one_epoch(
947
        epoch,
948
        model,
949
        loader,
950
        optimizer,
951
        loss_fn,
952
        args,
953
        device=torch.device('cuda'),
954
        lr_scheduler=None,
955
        saver=None,
956
        output_dir=None,
957
        amp_autocast=suppress,
958
        loss_scaler=None,
959
        model_ema=None,
960
        mixup_fn=None,
961
        num_updates_total=None,
962
):
963
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
964
        if args.prefetcher and loader.mixup_enabled:
965
            loader.mixup_enabled = False
966
        elif mixup_fn is not None:
967
            mixup_fn.mixup_enabled = False
968

969
    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
970
    has_no_sync = hasattr(model, "no_sync")
971
    update_time_m = utils.AverageMeter()
972
    data_time_m = utils.AverageMeter()
973
    losses_m = utils.AverageMeter()
974

975
    model.train()
976

977
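    # Gradient accumulation bookkeeping: an optimizer step happens every `accum_steps` batches,
    # and the final (possibly smaller) group of batches in the epoch uses `last_accum_steps`.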
    accum_steps = args.grad_accum_steps
978
    last_accum_steps = len(loader) % accum_steps
979
    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
980
    num_updates = epoch * updates_per_epoch
981
    last_batch_idx = len(loader) - 1
982
    last_batch_idx_to_accum = len(loader) - last_accum_steps
983

984
    data_start_time = update_start_time = time.time()
985
    optimizer.zero_grad()
986
    update_sample_count = 0
987
    for batch_idx, (input, target) in enumerate(loader):
988
        last_batch = batch_idx == last_batch_idx
989
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
990
        update_idx = batch_idx // accum_steps
991
        if batch_idx >= last_batch_idx_to_accum:
992
            accum_steps = last_accum_steps
993

994
        if not args.prefetcher:
995
            input, target = input.to(device), target.to(device)
996
            if mixup_fn is not None:
997
                input, target = mixup_fn(input, target)
998
        if args.channels_last:
999
            input = input.contiguous(memory_format=torch.channels_last)
1000

1001
        # multiply by accum steps to get equivalent for full update
1002
        data_time_m.update(accum_steps * (time.time() - data_start_time))
1003

1004
        def _forward():
1005
            with amp_autocast():
1006
                output = model(input)
1007
                loss = loss_fn(output, target)
1008
            if accum_steps > 1:
1009
                loss /= accum_steps
1010
            return loss
1011

1012
        def _backward(_loss):
1013
            if loss_scaler is not None:
1014
                loss_scaler(
1015
                    _loss,
1016
                    optimizer,
1017
                    clip_grad=args.clip_grad,
1018
                    clip_mode=args.clip_mode,
1019
                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
1020
                    create_graph=second_order,
1021
                    need_update=need_update,
1022
                )
1023
            else:
1024
                _loss.backward(create_graph=second_order)
1025
                if need_update:
1026
                    if args.clip_grad is not None:
1027
                        utils.dispatch_clip_grad(
1028
                            model_parameters(model, exclude_head='agc' in args.clip_mode),
1029
                            value=args.clip_grad,
1030
                            mode=args.clip_mode,
1031
                        )
1032
                    optimizer.step()
1033

1034
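        # For DDP models, skip gradient all-reduce on accumulation-only steps via no_sync();
        # gradients are synchronized only on the batch that triggers an optimizer update.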
        if has_no_sync and not need_update:
1035
            with model.no_sync():
1036
                loss = _forward()
1037
                _backward(loss)
1038
        else:
1039
            loss = _forward()
1040
            _backward(loss)
1041

1042
        if not args.distributed:
1043
            losses_m.update(loss.item() * accum_steps, input.size(0))
1044
        update_sample_count += input.size(0)
1045

1046
        if not need_update:
1047
            data_start_time = time.time()
1048
            continue
1049

1050
        num_updates += 1
1051
        optimizer.zero_grad()
1052
        if model_ema is not None:
1053
            model_ema.update(model, step=num_updates)
1054

1055
        if args.synchronize_step and device.type == 'cuda':
1056
            torch.cuda.synchronize()
1057
        time_now = time.time()
1058
        update_time_m.update(time.time() - update_start_time)
1059
        update_start_time = time_now
1060

1061
        if update_idx % args.log_interval == 0:
1062
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
1063
            lr = sum(lrl) / len(lrl)
1064

1065
            if args.distributed:
1066
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
1067
                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
1068
                update_sample_count *= args.world_size
1069

1070
            if utils.is_primary(args):
1071
                _logger.info(
1072
                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
1073
                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)]  '
1074
                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g})  '
1075
                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s  '
1076
                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s)  '
1077
                    f'LR: {lr:.3e}  '
1078
                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
1079
                )
1080

1081
                if args.save_images and output_dir:
1082
                    torchvision.utils.save_image(
1083
                        input,
1084
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
1085
                        padding=0,
1086
                        normalize=True
1087
                    )
1088

1089
        if saver is not None and args.recovery_interval and (
1090
                (update_idx + 1) % args.recovery_interval == 0):
1091
            saver.save_recovery(epoch, batch_idx=update_idx)
1092

1093
        if lr_scheduler is not None:
1094
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
1095

1096
        update_sample_count = 0
1097
        data_start_time = time.time()
1098
        # end for
1099

1100
    if hasattr(optimizer, 'sync_lookahead'):
1101
        optimizer.sync_lookahead()
1102

1103
    return OrderedDict([('loss', losses_m.avg)])
1104

1105

1106
def validate(
1107
        model,
1108
        loader,
1109
        loss_fn,
1110
        args,
1111
        device=torch.device('cuda'),
1112
        amp_autocast=suppress,
1113
        log_suffix=''
1114
):
1115
    batch_time_m = utils.AverageMeter()
1116
    losses_m = utils.AverageMeter()
1117
    top1_m = utils.AverageMeter()
1118
    top5_m = utils.AverageMeter()
1119

1120
    model.eval()
1121

1122
    end = time.time()
1123
    last_idx = len(loader) - 1
1124
    with torch.no_grad():
1125
        for batch_idx, (input, target) in enumerate(loader):
1126
            last_batch = batch_idx == last_idx
1127
            if not args.prefetcher:
1128
                input = input.to(device)
1129
                target = target.to(device)
1130
            if args.channels_last:
1131
                input = input.contiguous(memory_format=torch.channels_last)
1132

1133
            with amp_autocast():
1134
                output = model(input)
1135
                if isinstance(output, (tuple, list)):
1136
                    output = output[0]
1137

1138
                # augmentation reduction
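                # (--tta N averages model outputs over each group of N consecutive samples)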
1139
                reduce_factor = args.tta
1140
                if reduce_factor > 1:
1141
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
1142
                    target = target[0:target.size(0):reduce_factor]
1143

1144
                loss = loss_fn(output, target)
1145
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
1146

1147
            if args.distributed:
1148
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
1149
                acc1 = utils.reduce_tensor(acc1, args.world_size)
1150
                acc5 = utils.reduce_tensor(acc5, args.world_size)
1151
            else:
1152
                reduced_loss = loss.data
1153

1154
            if device.type == 'cuda':
1155
                torch.cuda.synchronize()
1156

1157
            losses_m.update(reduced_loss.item(), input.size(0))
1158
            top1_m.update(acc1.item(), output.size(0))
1159
            top5_m.update(acc5.item(), output.size(0))
1160

1161
            batch_time_m.update(time.time() - end)
1162
            end = time.time()
1163
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
1164
                log_name = 'Test' + log_suffix
1165
                _logger.info(
1166
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}]  '
1167
                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f})  '
1168
                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f})  '
1169
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f})  '
1170
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
1171
                )
1172

1173
    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
1174

1175
    return metrics
1176

1177

1178
if __name__ == '__main__':
1179
    main()