pytorch-image-models

Форк
0
129 строк · 4.5 Кб
1
import logging
2
from .constants import *
3

4

5
_logger = logging.getLogger(__name__)
6

7

8
def resolve_data_config(
9
        args=None,
10
        pretrained_cfg=None,
11
        model=None,
12
        use_test_size=False,
13
        verbose=False
14
):
15
    assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
16
    args = args or {}
17
    pretrained_cfg = pretrained_cfg or {}
18
    if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
19
        pretrained_cfg = model.pretrained_cfg
20
    data_config = {}
21

22
    # Resolve input/image size
23
    in_chans = 3
24
    if args.get('in_chans', None) is not None:
25
        in_chans = args['in_chans']
26
    elif args.get('chans', None) is not None:
27
        in_chans = args['chans']
28

29
    input_size = (in_chans, 224, 224)
30
    if args.get('input_size', None) is not None:
31
        assert isinstance(args['input_size'], (tuple, list))
32
        assert len(args['input_size']) == 3
33
        input_size = tuple(args['input_size'])
34
        in_chans = input_size[0]  # input_size overrides in_chans
35
    elif args.get('img_size', None) is not None:
36
        assert isinstance(args['img_size'], int)
37
        input_size = (in_chans, args['img_size'], args['img_size'])
38
    else:
39
        if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
40
            input_size = pretrained_cfg['test_input_size']
41
        elif pretrained_cfg.get('input_size', None) is not None:
42
            input_size = pretrained_cfg['input_size']
43
    data_config['input_size'] = input_size
44

45
    # resolve interpolation method
46
    data_config['interpolation'] = 'bicubic'
47
    if args.get('interpolation', None):
48
        data_config['interpolation'] = args['interpolation']
49
    elif pretrained_cfg.get('interpolation', None):
50
        data_config['interpolation'] = pretrained_cfg['interpolation']
51

52
    # resolve dataset + model mean for normalization
53
    data_config['mean'] = IMAGENET_DEFAULT_MEAN
54
    if args.get('mean', None) is not None:
55
        mean = tuple(args['mean'])
56
        if len(mean) == 1:
57
            mean = tuple(list(mean) * in_chans)
58
        else:
59
            assert len(mean) == in_chans
60
        data_config['mean'] = mean
61
    elif pretrained_cfg.get('mean', None):
62
        data_config['mean'] = pretrained_cfg['mean']
63

64
    # resolve dataset + model std deviation for normalization
65
    data_config['std'] = IMAGENET_DEFAULT_STD
66
    if args.get('std', None) is not None:
67
        std = tuple(args['std'])
68
        if len(std) == 1:
69
            std = tuple(list(std) * in_chans)
70
        else:
71
            assert len(std) == in_chans
72
        data_config['std'] = std
73
    elif pretrained_cfg.get('std', None):
74
        data_config['std'] = pretrained_cfg['std']
75

76
    # resolve default inference crop
77
    crop_pct = DEFAULT_CROP_PCT
78
    if args.get('crop_pct', None):
79
        crop_pct = args['crop_pct']
80
    else:
81
        if use_test_size and pretrained_cfg.get('test_crop_pct', None):
82
            crop_pct = pretrained_cfg['test_crop_pct']
83
        elif pretrained_cfg.get('crop_pct', None):
84
            crop_pct = pretrained_cfg['crop_pct']
85
    data_config['crop_pct'] = crop_pct
86

87
    # resolve default crop percentage
88
    crop_mode = DEFAULT_CROP_MODE
89
    if args.get('crop_mode', None):
90
        crop_mode = args['crop_mode']
91
    elif pretrained_cfg.get('crop_mode', None):
92
        crop_mode = pretrained_cfg['crop_mode']
93
    data_config['crop_mode'] = crop_mode
94

95
    if verbose:
96
        _logger.info('Data processing configuration for current model + dataset:')
97
        for n, v in data_config.items():
98
            _logger.info('\t%s: %s' % (n, str(v)))
99

100
    return data_config
101

102

103
def resolve_model_data_config(
104
        model,
105
        args=None,
106
        pretrained_cfg=None,
107
        use_test_size=False,
108
        verbose=False,
109
):
110
    """ Resolve Model Data Config
111
    This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
112

113
    Args:
114
        model (nn.Module): the model instance
115
        args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
116
        pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
117
        use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
118
        verbose (bool): enable extra logging of resolved values
119

120
    Returns:
121
        dictionary of config
122
    """
123
    return resolve_data_config(
124
        args=args,
125
        pretrained_cfg=pretrained_cfg,
126
        model=model,
127
        use_test_size=use_test_size,
128
        verbose=verbose,
129
    )
130

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.