# pytorch-image-models
# 129 lines · 4.5 KB
import logging
from typing import Any, Dict, Optional

from .constants import *
3
4
# Module-level logger, named after this module; used to report the resolved config when verbose.
_logger = logging.getLogger(name=__name__)
6
7
def _resolve_norm_stat(key, default, args, pretrained_cfg, in_chans):
    """Resolve one per-channel normalization statistic ('mean' or 'std').

    Precedence: ``args[key]`` (when not None) > ``pretrained_cfg[key]`` (when truthy) > ``default``.
    A value given through ``args`` is converted to a tuple; a single value is repeated
    ``in_chans`` times, otherwise its length must equal ``in_chans``.
    """
    value = args.get(key, None)
    if value is not None:
        value = tuple(value)
        if len(value) == 1:
            # broadcast a single value across every input channel
            return value * in_chans
        assert len(value) == in_chans, \
            '%s has %d values, expected 1 or %d (in_chans)' % (key, len(value), in_chans)
        return value
    if pretrained_cfg.get(key, None):
        return pretrained_cfg[key]
    return default


def resolve_data_config(
        args: Optional[Dict[str, Any]] = None,
        pretrained_cfg: Optional[Dict[str, Any]] = None,
        model: Any = None,
        use_test_size: bool = False,
        verbose: bool = False,
) -> Dict[str, Any]:
    """Resolve the data pre-processing config for a model / dataset combination.

    Each value is taken, in order of precedence, from ``args``, then from ``pretrained_cfg``
    (or, when that is empty, from the ``pretrained_cfg`` attribute of ``model``), then from
    built-in defaults.

    Args:
        args: command line arguments / configuration in dict form. Keys read:
            'in_chans' / 'chans', 'input_size', 'img_size', 'interpolation', 'mean', 'std',
            'crop_pct', 'crop_mode'.
        pretrained_cfg: pretrained model config dict (overrides ``model.pretrained_cfg``).
        model: model instance; only its ``pretrained_cfg`` attribute is consulted.
        use_test_size: prefer 'test_input_size' / 'test_crop_pct' from the pretrained config
            when present.
        verbose: log every resolved value at INFO level.

    Returns:
        dict with keys 'input_size', 'interpolation', 'mean', 'std', 'crop_pct', 'crop_mode'.
        An 'input_size' taken from the pretrained config is returned as stored there.

    Raises:
        AssertionError: if none of model / args / pretrained_cfg is given, or if 'input_size',
            'img_size', 'mean' or 'std' in ``args`` have an unexpected type or length.
    """
    # Compare model against None instead of truth-testing it: truth testing falls back to
    # __len__ when an object defines it, so a valid model object could evaluate False.
    assert model is not None or args or pretrained_cfg, \
        "At least one of model, args, or pretrained_cfg required for data config."
    args = args or {}
    pretrained_cfg = pretrained_cfg or {}
    if not pretrained_cfg and model is not None:
        # Fall back to the config attached to the model; a missing or None attribute
        # leaves an empty config instead of crashing on the .get() calls below.
        pretrained_cfg = getattr(model, 'pretrained_cfg', None) or {}
    data_config = {}

    # resolve input/image size
    in_chans = 3
    if args.get('in_chans', None) is not None:
        in_chans = args['in_chans']
    elif args.get('chans', None) is not None:
        in_chans = args['chans']

    input_size = (in_chans, 224, 224)
    if args.get('input_size', None) is not None:
        assert isinstance(args['input_size'], (tuple, list))
        assert len(args['input_size']) == 3
        input_size = tuple(args['input_size'])
    elif args.get('img_size', None) is not None:
        assert isinstance(args['img_size'], int)
        input_size = (in_chans, args['img_size'], args['img_size'])
    elif use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
        input_size = pretrained_cfg['test_input_size']
    elif pretrained_cfg.get('input_size', None) is not None:
        input_size = pretrained_cfg['input_size']
    # input_size may come from args['input_size'] or from the pretrained config, so its
    # channel count is authoritative. mean/std below are expanded/validated against it
    # so they agree with the returned data_config['input_size'].
    in_chans = input_size[0]
    data_config['input_size'] = input_size

    # resolve interpolation method
    data_config['interpolation'] = 'bicubic'
    if args.get('interpolation', None):
        data_config['interpolation'] = args['interpolation']
    elif pretrained_cfg.get('interpolation', None):
        data_config['interpolation'] = pretrained_cfg['interpolation']

    # resolve dataset + model mean / std deviation for normalization
    data_config['mean'] = _resolve_norm_stat(
        'mean', IMAGENET_DEFAULT_MEAN, args, pretrained_cfg, in_chans)
    data_config['std'] = _resolve_norm_stat(
        'std', IMAGENET_DEFAULT_STD, args, pretrained_cfg, in_chans)

    # resolve default inference crop percentage
    crop_pct = DEFAULT_CROP_PCT
    if args.get('crop_pct', None):
        crop_pct = args['crop_pct']
    elif use_test_size and pretrained_cfg.get('test_crop_pct', None):
        crop_pct = pretrained_cfg['test_crop_pct']
    elif pretrained_cfg.get('crop_pct', None):
        crop_pct = pretrained_cfg['crop_pct']
    data_config['crop_pct'] = crop_pct

    # resolve default crop mode
    crop_mode = DEFAULT_CROP_MODE
    if args.get('crop_mode', None):
        crop_mode = args['crop_mode']
    elif pretrained_cfg.get('crop_mode', None):
        crop_mode = pretrained_cfg['crop_mode']
    data_config['crop_mode'] = crop_mode

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
        for n, v in data_config.items():
            # lazy %-style args: formatting only happens if the record is emitted
            _logger.info('\t%s: %s', n, v)

    return data_config
101
102
def resolve_model_data_config(
        model,
        args=None,
        pretrained_cfg=None,
        use_test_size=False,
        verbose=False,
):
    """Resolve the data config for ``model``.

    Convenience wrapper around ``resolve_data_config()`` that takes the model as the first
    positional argument. Every argument is forwarded unchanged.

    Args:
        model (nn.Module): the model instance; its ``pretrained_cfg`` attribute is consulted
            when no ``pretrained_cfg`` is passed explicitly
        args (dict): command line arguments / configuration in dict form; takes precedence
            over any pretrained config
        pretrained_cfg (dict): pretrained model config; takes precedence over the config
            attached to ``model``
        use_test_size (bool): prefer the test-time input resolution (when one exists) over
            the default train resolution
        verbose (bool): log the resolved values

    Returns:
        dict: the resolved data config
    """
    return resolve_data_config(
        model=model,
        args=args,
        pretrained_cfg=pretrained_cfg,
        use_test_size=use_test_size,
        verbose=verbose,
    )
130