#!/usr/bin/env python3
""" ImageNet Training Script

This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet
training results with some of the latest networks and training techniques. It favours canonical PyTorch
and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed
and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit.

This script was started from an early version of the PyTorch ImageNet example
(https://github.com/pytorch/examples/tree/master/imagenet)

NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
(https://github.com/NVIDIA/apex/tree/master/examples/imagenet)

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import importlib
import json
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial

import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm import utils
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler_v2, scheduler_kwargs
from timm.utils import ApexScaler, NativeScaler

try:
    from apex import amp
    from apex.parallel import DistributedDataParallel as ApexDDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    has_apex = False

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False

has_compile = hasattr(torch, 'compile')


_logger = logging.getLogger('train')

# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
                    help='YAML config file specifying default arguments')
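# Example YAML config consumed via -c/--config (illustrative, not from the original; keys are
# the argparse dest names used by parser.set_defaults(**cfg) in _parse_args below):
#
#   model: resnet50
#   batch_size: 256
#   amp: true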


parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset parameters
group = parser.add_argument_group('Dataset parameters')
# Keep this argument outside the dataset group because it is positional.
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (positional is *deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
                   help='dataset validation split (default: validation)')
parser.add_argument('--train-num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in train split, for IterableDatasets.')
parser.add_argument('--val-num-samples', default=None, type=int,
                    metavar='N', help='Manually specify num samples in validation split, for IterableDatasets.')
group.add_argument('--dataset-download', action='store_true', default=False,
                   help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                   help='path to class to idx mapping file (default: "")')
group.add_argument('--input-img-mode', default=None, type=str,
                   help='Dataset image conversion mode for input images.')
group.add_argument('--input-key', default=None, type=str,
                   help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
                   help='Dataset key for target labels.')

# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                   help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
                   help='Start with pretrained version of specified network (if avail)')
group.add_argument('--pretrained-path', default=None, type=str,
                   help='Load this checkpoint as if it were the pretrained weights (with adaptation).')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                   help='Load this checkpoint into model after initialization (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
                   help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
                   help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
                   help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
                   help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
                   help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
                   help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
                   metavar='N N N',
                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
                   metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                   help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                   help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                   help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                   help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                   help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
                   help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
                   help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-accum-steps', type=int, default=1, metavar='N',
                   help='The number of steps to accumulate gradients (default: 1)')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
                   help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
                   help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
group.add_argument('--head-init-scale', default=None, type=float,
                   help='Head initialization scale')
group.add_argument('--head-init-bias', default=None, type=float,
                   help='Head initialization bias value')

# scripting / codegen
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")

# Device & distributed
group = parser.add_argument_group('Device parameters')
group.add_argument('--device', default='cuda', type=str,
                   help="Device (accelerator) to use.")
group.add_argument('--amp', action='store_true', default=False,
                   help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
                   help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
                   help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
                   help='Disable broadcast buffers for native DDP.')
group.add_argument('--synchronize-step', action='store_true', default=False,
                   help='Call torch.cuda.synchronize() at the end of each step.')
186group.add_argument("--local_rank", default=0, type=int)
187parser.add_argument('--device-modules', default=None, type=str, nargs='+',
188help="Python imports for device backend modules.")
189
190# Optimizer parameters
191group = parser.add_argument_group('Optimizer parameters')
192group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
193help='Optimizer (default: "sgd")')
194group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
195help='Optimizer Epsilon (default: None, use opt default)')
196group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
197help='Optimizer Betas (default: None, use opt default)')
198group.add_argument('--momentum', type=float, default=0.9, metavar='M',
199help='Optimizer momentum (default: 0.9)')
200group.add_argument('--weight-decay', type=float, default=2e-5,
201help='weight decay (default: 2e-5)')
202group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
203help='Clip gradient norm (default: None, no clipping)')
204group.add_argument('--clip-mode', type=str, default='norm',
205help='Gradient clipping mode. One of ("norm", "value", "agc")')
206group.add_argument('--layer-decay', type=float, default=None,
207help='layer-wise learning rate decay (default: None)')
208group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
209
210# Learning rate schedule parameters
211group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
                   help='LR scheduler (default: "cosine")')
group.add_argument('--sched-on-updates', action='store_true', default=False,
                   help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
                   help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
                   help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                   help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                   help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                   help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                   help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
                   help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                   help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
                   help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
                   help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
                   help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                   help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
                   help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar='MILESTONES',
                   help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
                   help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                   help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
                   help='Exclude warmup period from decay schedule.')
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                   help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                   help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                   help='LR decay rate (default: 0.1)')

# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
                   help='Disable all training augmentation, override other train aug args')
group.add_argument('--train-crop-mode', type=str, default=None,
                   help='Crop mode for training images.')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                   help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                   help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
                   help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
                   help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                   help='Color jitter factor (default: 0.4)')
group.add_argument('--color-jitter-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying any color jitter.')
group.add_argument('--grayscale-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying random grayscale conversion.')
group.add_argument('--gaussian-blur-prob', type=float, default=None, metavar='PCT',
                   help='Probability of applying gaussian blur.')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
                   help='Use AutoAugment policy. "v0" or "original". (default: None)')
group.add_argument('--aug-repeats', type=float, default=0,
                   help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
                   help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
                   help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
                   help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-sum', action='store_true', default=False,
                   help='Sum over classes when using BCE loss.')
group.add_argument('--bce-target-thresh', type=float, default=None,
                   help='Threshold for binarizing softened BCE targets (default: None, disabled).')
group.add_argument('--bce-pos-weight', type=float, default=None,
                   help='Positive weighting for BCE loss.')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
                   help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
                   help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
                   help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
                   help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
                   help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
                   help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                   help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
                   help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
                   help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
                   help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                   help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
                   help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
                   help='Training interpolation (random, bilinear, bicubic; default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                   help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
                   help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                   help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                   help='Drop block rate (default: None)')

# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
                   help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
                   help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
                   help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
                   help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
                   help='Enable separate BN layers per augmentation split.')

# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
                   help='Enable tracking moving average of model weights.')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                   help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
                   help='Decay factor for model weights moving average (default: 0.9998)')
group.add_argument('--model-ema-warmup', action='store_true',
                   help='Enable warmup for model EMA decay.')

# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
                   help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
                   help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
                   help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                   help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
                   help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                   help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
                   help='save images of input batches every log interval for debugging')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
                   help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
                   help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
                   help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                   help='Best metric (default: "top1")')
group.add_argument('--tta', type=int, default=0, metavar='N',
                   help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')


def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text

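# Example invocations (illustrative, not from the original; dataset paths are placeholders):
#   python train.py --data-dir /path/to/imagenet --model resnet50 -b 128 --amp
#   python train.py -c config.yaml --data-dir /path/to/imagenet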

def main():
    utils.setup_default_logging()
    args, args_text = _parse_args()

    if args.device_modules:
        for module in args.device_modules:
            importlib.import_module(module)

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    args.prefetcher = not args.no_prefetcher
    args.grad_accum_steps = max(1, args.grad_accum_steps)
    device = utils.init_distributed_device(args)
    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process. '
            f'Process {args.rank}, total {args.world_size}, device {args.device}.')
    else:
        _logger.info(f'Training with a single process on 1 device ({args.device}).')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
    if args.amp:
        if args.amp_impl == 'apex':
            assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
            use_amp = 'apex'
            assert args.amp_dtype == 'float16'
        else:
            assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
            use_amp = 'native'
            assert args.amp_dtype in ('float16', 'bfloat16')
        if args.amp_dtype == 'bfloat16':
            amp_dtype = torch.bfloat16

    utils.random_seed(args.seed, args.rank)

    if args.fuser:
        utils.set_jit_fuser(args.fuser)
    if args.fast_norm:
        set_fast_norm()

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    factory_kwargs = {}
    if args.pretrained_path:
        # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
        factory_kwargs['pretrained_cfg_overlay'] = dict(
            file=args.pretrained_path,
            num_classes=-1,  # force head adaptation
        )

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        **factory_kwargs,
        **args.model_kwargs,
    )
    if args.head_init_scale is not None:
        with torch.no_grad():
            model.get_classifier().weight.mul_(args.head_init_scale)
            model.get_classifier().bias.mul_(args.head_init_scale)
    if args.head_init_bias is not None:
        nn.init.constant_(model.get_classifier().bias, args.head_init_bias)

    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.to(device=device)
    if args.channels_last:
        model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        args.dist_bn = ''  # disable dist_bn when sync BN active
        assert not args.split_bn
        if has_apex and use_amp == 'apex':
            # Apex SyncBN used with Apex AMP
            # WARNING this won't currently work with models using BatchNormAct2d
            model = convert_syncbn_model(model)
        else:
            model = convert_sync_batchnorm(model)
        if utils.is_primary(args):
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not args.torchcompile
        assert use_amp != 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    if not args.lr:
        global_batch_size = args.batch_size * args.world_size * args.grad_accum_steps
        batch_ratio = global_batch_size / args.lr_base_size
        if not args.lr_base_scale:
            on = args.opt.lower()
            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
        if args.lr_base_scale == 'sqrt':
            batch_ratio = batch_ratio ** 0.5
        args.lr = args.lr_base * batch_ratio
        if utils.is_primary(args):
            _logger.info(
                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
                f'and effective global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
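    # Worked example (illustrative, not from the original): with the defaults --lr-base 0.1 and
    # --lr-base-size 256, a global batch of 512 (batch_size 128 * world_size 4 * grad_accum_steps 1)
    # gives lr = 0.1 * 512 / 256 = 0.2 with linear scaling, or 0.1 * sqrt(2) ~= 0.141 with sqrt scaling.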

    optimizer = create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        **args.opt_kwargs,
    )

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        assert device.type == 'cuda'
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if utils.is_primary(args):
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        try:
            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA only AMP for PyTorch < 1.10
            assert device.type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        if device.type == 'cuda' and amp_dtype == torch.float16:
            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
            loss_scaler = NativeScaler()
        if utils.is_primary(args):
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if utils.is_primary(args):
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model,
            args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=utils.is_primary(args),
        )

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = utils.ModelEmaV3(
            model,
            decay=args.model_ema_decay,
            use_warmup=args.model_ema_warmup,
            device='cpu' if args.model_ema_force_cpu else None,
        )
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)
        if args.torchcompile:
            model_ema = torch.compile(model_ema, backend=args.torchcompile)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp == 'apex':
            # Apex DDP preferred unless native amp is activated
            if utils.is_primary(args):
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if utils.is_primary(args):
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
        # NOTE: EMA model does not need to be wrapped by DDP

    if args.torchcompile:
        # torch compile should be done after DDP
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        model = torch.compile(model, backend=args.torchcompile)

    # create the train and eval datasets
    if args.data and not args.data_dir:
        args.data_dir = args.data
    if args.input_img_mode is None:
        input_img_mode = 'RGB' if data_config['input_size'][0] == 3 else 'L'
    else:
        input_img_mode = args.input_img_mode

    dataset_train = create_dataset(
        args.dataset,
        root=args.data_dir,
        split=args.train_split,
        is_training=True,
        class_map=args.class_map,
        download=args.dataset_download,
        batch_size=args.batch_size,
        seed=args.seed,
        repeats=args.epoch_repeats,
        input_img_mode=input_img_mode,
        input_key=args.input_key,
        target_key=args.target_key,
        num_samples=args.train_num_samples,
    )

    if args.val_split:
        dataset_eval = create_dataset(
            args.dataset,
            root=args.data_dir,
            split=args.val_split,
            is_training=False,
            class_map=args.class_map,
            download=args.dataset_download,
            batch_size=args.batch_size,
            input_img_mode=input_img_mode,
            input_key=args.input_key,
            target_key=args.target_key,
            num_samples=args.val_num_samples,
        )

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob,
            switch_prob=args.mixup_switch_prob,
            mode=args.mixup_mode,
            label_smoothing=args.smoothing,
            num_classes=args.num_classes
        )
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support de-interleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        train_crop_mode=args.train_crop_mode,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        color_jitter_prob=args.color_jitter_prob,
        grayscale_prob=args.grayscale_prob,
        gaussian_blur_prob=args.gaussian_blur_prob,
        auto_augment=args.aa,
        num_aug_repeats=args.aug_repeats,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        device=device,
        use_prefetcher=args.prefetcher,
        use_multi_epochs_loader=args.use_multi_epochs_loader,
        worker_seeding=args.worker_seeding,
    )

    loader_eval = None
    if args.val_split:
        eval_workers = args.workers
        if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset):
            # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
            eval_workers = min(2, args.workers)
        loader_eval = create_loader(
            dataset_eval,
            input_size=data_config['input_size'],
            batch_size=args.validation_batch_size or args.batch_size,
            is_training=False,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
            num_workers=eval_workers,
            distributed=args.distributed,
            crop_pct=data_config['crop_pct'],
            pin_memory=args.pin_mem,
            device=device,
            use_prefetcher=args.prefetcher,
        )

    # setup loss function
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_active:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(
                smoothing=args.smoothing,
                target_threshold=args.bce_target_thresh,
                sum_classes=args.bce_sum,
                pos_weight=args.bce_pos_weight,
            )
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.to(device=device)
    validate_loss_fn = nn.CrossEntropyLoss().to(device=device)
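    # Illustrative summary (not in the original): JSD requires --aug-splits > 1; with mixup/cutmix
    # active, targets are already soft, so SoftTargetCrossEntropy (or BCE) is used; otherwise label
    # smoothing CE or plain CE applies. Validation always uses standard CrossEntropyLoss.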

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric if loader_eval is not None else 'loss'
    decreasing_metric = eval_metric == 'loss'
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
        saver = utils.CheckpointSaver(
            model=model,
            optimizer=optimizer,
            args=args,
            model_ema=model_ema,
            amp_scaler=loss_scaler,
            checkpoint_dir=output_dir,
            recovery_dir=output_dir,
            decreasing=decreasing_metric,
            max_history=args.checkpoint_hist
        )
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    if utils.is_primary(args) and args.log_wandb:
        if has_wandb:
            wandb.init(project=args.experiment, config=args)
        else:
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
    lr_scheduler, num_epochs = create_scheduler_v2(
        optimizer,
        **scheduler_kwargs(args, decreasing_metric=decreasing_metric),
        updates_per_epoch=updates_per_epoch,
    )
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        if args.sched_on_updates:
            lr_scheduler.step_update(start_epoch * updates_per_epoch)
        else:
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
        _logger.info(
            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

    results = []
    try:
        for epoch in range(start_epoch, num_epochs):
            if hasattr(dataset_train, 'set_epoch'):
                dataset_train.set_epoch(epoch)
            elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch,
                model,
                loader_train,
                optimizer,
                train_loss_fn,
                args,
                lr_scheduler=lr_scheduler,
                saver=saver,
                output_dir=output_dir,
                amp_autocast=amp_autocast,
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                num_updates_total=num_epochs * updates_per_epoch,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if utils.is_primary(args):
                    _logger.info("Distributing BatchNorm running means and vars")
                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            if loader_eval is not None:
                eval_metrics = validate(
                    model,
                    loader_eval,
                    validate_loss_fn,
                    args,
                    device=device,
                    amp_autocast=amp_autocast,
                )

                if model_ema is not None and not args.model_ema_force_cpu:
                    if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                        utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

                    ema_eval_metrics = validate(
                        model_ema,
                        loader_eval,
                        validate_loss_fn,
                        args,
                        device=device,
                        amp_autocast=amp_autocast,
                        log_suffix=' (EMA)',
                    )
                    eval_metrics = ema_eval_metrics
            else:
                eval_metrics = None

            if output_dir is not None:
                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                utils.update_summary(
                    epoch,
                    train_metrics,
                    eval_metrics,
                    filename=os.path.join(output_dir, 'summary.csv'),
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if eval_metrics is not None:
                latest_metric = eval_metrics[eval_metric]
            else:
                latest_metric = train_metrics[eval_metric]

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=latest_metric)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, latest_metric)

            results.append({
                'epoch': epoch,
                'train': train_metrics,
                'validation': eval_metrics,
            })

    except KeyboardInterrupt:
        pass

    results = {'all': results}
    if best_metric is not None:
        results['best'] = results['all'][best_epoch - start_epoch]
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    print(f'--result\n{json.dumps(results, indent=4)}')


def train_one_epoch(
        epoch,
        model,
        loader,
        optimizer,
        loss_fn,
        args,
        device=torch.device('cuda'),
        lr_scheduler=None,
        saver=None,
        output_dir=None,
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        num_updates_total=None,
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
            loader.mixup_enabled = False
        elif mixup_fn is not None:
            mixup_fn.mixup_enabled = False

    second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
    has_no_sync = hasattr(model, "no_sync")
    update_time_m = utils.AverageMeter()
    data_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()

    model.train()

    accum_steps = args.grad_accum_steps
    last_accum_steps = len(loader) % accum_steps
    updates_per_epoch = (len(loader) + accum_steps - 1) // accum_steps
    num_updates = epoch * updates_per_epoch
    last_batch_idx = len(loader) - 1
    last_batch_idx_to_accum = len(loader) - last_accum_steps
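    # Worked example (illustrative, not from the original): with len(loader) == 10 and
    # accum_steps == 4: last_accum_steps == 2, updates_per_epoch == 3, and the final optimizer
    # update accumulates over the trailing 2 batches starting at batch index 8.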

    data_start_time = update_start_time = time.time()
    optimizer.zero_grad()
    update_sample_count = 0
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_batch_idx
        need_update = last_batch or (batch_idx + 1) % accum_steps == 0
        update_idx = batch_idx // accum_steps
        if batch_idx >= last_batch_idx_to_accum:
            accum_steps = last_accum_steps

        if not args.prefetcher:
            input, target = input.to(device), target.to(device)
            if mixup_fn is not None:
                input, target = mixup_fn(input, target)
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)

        # multiply by accum steps to get equivalent for full update
        data_time_m.update(accum_steps * (time.time() - data_start_time))

        def _forward():
            with amp_autocast():
                output = model(input)
                loss = loss_fn(output, target)
            if accum_steps > 1:
                loss /= accum_steps
            return loss

        def _backward(_loss):
            if loss_scaler is not None:
                loss_scaler(
                    _loss,
                    optimizer,
                    clip_grad=args.clip_grad,
                    clip_mode=args.clip_mode,
                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
                    create_graph=second_order,
                    need_update=need_update,
                )
            else:
                _loss.backward(create_graph=second_order)
                if need_update:
                    if args.clip_grad is not None:
                        utils.dispatch_clip_grad(
                            model_parameters(model, exclude_head='agc' in args.clip_mode),
                            value=args.clip_grad,
                            mode=args.clip_mode,
                        )
                    optimizer.step()

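        # Note (illustrative, not from the original): under DDP, no_sync() skips the gradient
        # all-reduce on intermediate accumulation micro-batches; gradients are synchronized only
        # on the micro-batch that triggers the optimizer update.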
        if has_no_sync and not need_update:
            with model.no_sync():
                loss = _forward()
                _backward(loss)
        else:
            loss = _forward()
            _backward(loss)

        if not args.distributed:
            losses_m.update(loss.item() * accum_steps, input.size(0))
        update_sample_count += input.size(0)

        if not need_update:
            data_start_time = time.time()
            continue

        num_updates += 1
        optimizer.zero_grad()
        if model_ema is not None:
            model_ema.update(model, step=num_updates)

        if args.synchronize_step and device.type == 'cuda':
            torch.cuda.synchronize()
        time_now = time.time()
        update_time_m.update(time.time() - update_start_time)
        update_start_time = time_now

        if update_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                losses_m.update(reduced_loss.item() * accum_steps, input.size(0))
                update_sample_count *= args.world_size

            if utils.is_primary(args):
                _logger.info(
                    f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                    f'({100. * update_idx / (updates_per_epoch - 1):>3.0f}%)] '
                    f'Loss: {losses_m.val:#.3g} ({losses_m.avg:#.3g}) '
                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
                    f'LR: {lr:.3e} '
                    f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                )

                if args.save_images and output_dir:
                    torchvision.utils.save_image(
                        input,
                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
                        padding=0,
                        normalize=True
                    )

        if saver is not None and args.recovery_interval and (
                (update_idx + 1) % args.recovery_interval == 0):
            saver.save_recovery(epoch, batch_idx=update_idx)

        if lr_scheduler is not None:
            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)

        update_sample_count = 0
        data_start_time = time.time()
        # end for

    if hasattr(optimizer, 'sync_lookahead'):
        optimizer.sync_lookahead()

    return OrderedDict([('loss', losses_m.avg)])


def validate(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix=''
):
    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
    top1_m = utils.AverageMeter()
    top5_m = utils.AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx
            if not args.prefetcher:
                input = input.to(device)
                target = target.to(device)
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            with amp_autocast():
                output = model(input)
                if isinstance(output, (tuple, list)):
                    output = output[0]

                # augmentation reduction
                reduce_factor = args.tta
                if reduce_factor > 1:
                    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                    target = target[0:target.size(0):reduce_factor]

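                # Illustrative (not from the original): with --tta 2 and outputs stacked as
                # [a1, a2, b1, b2], unfold(0, 2, 2).mean(dim=2) averages consecutive pairs to
                # [(a1 + a2) / 2, (b1 + b2) / 2], and the target is subsampled to match.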
                loss = loss_fn(output, target)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                acc1 = utils.reduce_tensor(acc1, args.world_size)
                acc5 = utils.reduce_tensor(acc5, args.world_size)
            else:
                reduced_loss = loss.data

            if device.type == 'cuda':
                torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            top1_m.update(acc1.item(), output.size(0))
            top5_m.update(acc5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
                    f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
                    f'Time: {batch_time_m.val:.3f} ({batch_time_m.avg:.3f}) '
                    f'Loss: {losses_m.val:>7.3f} ({losses_m.avg:>6.3f}) '
                    f'Acc@1: {top1_m.val:>7.3f} ({top1_m.avg:>7.3f}) '
                    f'Acc@5: {top5_m.val:>7.3f} ({top5_m.avg:>7.3f})'
                )

    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])

    return metrics


if __name__ == '__main__':
    main()