pytorch-image-models

Форк
0
/
inference.py 
366 строк · 15.6 Кб
1
#!/usr/bin/env python3
2
"""PyTorch Inference Script
3

4
An example inference script that outputs top-k class ids for images in a folder into a csv.
5

6
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
7
"""
8
import argparse
9
import json
10
import logging
11
import os
12
import time
13
from contextlib import suppress
14
from functools import partial
15

16
import numpy as np
17
import pandas as pd
18
import torch
19

20
from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
21
from timm.layers import apply_test_time_pool
22
from timm.models import create_model
23
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
24

25
# Detect optional acceleration libraries; each flag gates a CLI feature below.
try:
    from apex import amp
    has_apex = True
except ImportError:
    has_apex = False

# Native AMP (torch.autocast) availability probe.
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

# functorch memory-efficient fusion backs the --aot-autograd option.
try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    # FIX: dropped unused `as e` binding from the original except clause.
    has_functorch = False

# torch.compile (PyTorch 2.x) backs the --torchcompile option.
has_compile = hasattr(torch, 'compile')
45

46

47
# Map of results-format name -> file extension appended by save_results().
# Note: all three json variants share the '.json' extension.
_FMT_EXT = {
    'json': '.json',
    'json-record': '.json',
    'json-split': '.json',
    'parquet': '.parquet',
    'csv': '.csv',
}

# Enable cudnn auto-tuner; beneficial for fixed-shape inference batches.
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('inference')
57

58

59
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
# --- dataset / input arguments ---
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
# --- model arguments ---
parser.add_argument('--model', '-m', metavar='MODEL', default='resnet50',
                    help='model architecture (default: resnet50)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
                    help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
# FIX: help string said "std deviation of of dataset"
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
# --- device / precision / compilation arguments ---
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
                    help='enable test time pool')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
                    help='use Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

# torchscript / torch.compile / AOT autograd are mutually exclusive.
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support.")

# --- results output arguments ---
parser.add_argument('--results-dir', type=str, default=None,
                    help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
                    help='results filename (relative to results-dir)')
parser.add_argument('--results-format', type=str, nargs='+', default=['csv'],
                    help='results format (one of "csv", "json", "json-split", "parquet")')
parser.add_argument('--results-separate-col', action='store_true', default=False,
                    help='separate output columns per result index.')
parser.add_argument('--topk', default=1, type=int,
                    metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
                    help='use full sample name in output (not just basename).')
parser.add_argument('--filename-col', type=str, default='filename',
                    help='name for filename / sample name column')
parser.add_argument('--index-col', type=str, default='index',
                    help='name for output indices column(s)')
# FIX: help string was copy-pasted from --index-col
parser.add_argument('--label-col', type=str, default='label',
                    help='name for output label column(s)')
parser.add_argument('--output-col', type=str, default=None,
                    help='name for logit/probs output column(s)')
# FIX: "colum" -> "column" typo in help
parser.add_argument('--output-type', type=str, default='prob',
                    help='output type column ("prob" for probabilities, "logit" for raw logits)')
# FIX: help advertised "detailed", but main() matches 'detail'
parser.add_argument('--label-type', type=str, default='description',
                    help='type of label to output, one of "none", "name", "description", "detail"')
parser.add_argument('--include-index', action='store_true', default=False,
                    help='include the class index in results')
parser.add_argument('--exclude-output', action='store_true', default=False,
                    help='exclude logits/probs from results, just indices. topk must be set !=0.')
154

155

156
def main():
    """Run batched inference over a dataset and write per-sample top-k results.

    Builds the model and data loader from CLI args, runs forward passes
    (optionally with AMP / torchscript / torch.compile / AOT autograd),
    collects top-k indices, labels, and outputs, then writes them to one or
    more result files (csv / json / parquet) and echoes results as JSON to
    stdout.
    """
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    # resolve AMP arguments based on PyTorch / Apex availability
    amp_autocast = suppress
    if args.amp:
        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
        assert args.amp_dtype in ('float16', 'bfloat16')
        amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        _logger.info('Running inference in mixed precision with native PyTorch AMP.')
    else:
        _logger.info('Running inference in float32. AMP not enabled.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    # create model; input channels resolved from --in-chans or --input-size
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=in_chans,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        **args.model_kwargs,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    _logger.info(
        f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model)
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    model = model.to(device)
    model.eval()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.torchscript:
        model = torch.jit.script(model)
    elif args.torchcompile:
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    root_dir = args.data or args.data_dir
    dataset = create_dataset(
        root=root_dir,
        name=args.dataset,
        split=args.split,
        class_map=args.class_map,
    )

    if test_time_pool:
        data_config['crop_pct'] = 1.0

    # iterable datasets (tfds / wds) are restricted to a single worker here
    workers = 1 if 'tfds' in args.dataset or 'wds' in args.dataset else args.workers
    loader = create_loader(
        dataset,
        batch_size=args.batch_size,
        use_prefetcher=True,
        num_workers=workers,
        device=device,
        **data_config,
    )

    to_label = None
    # FIX: accept both 'detail' and 'detailed' -- the CLI help advertised
    # 'detailed' but the original code only matched 'detail'.
    if args.label_type in ('name', 'description', 'detail', 'detailed'):
        imagenet_subset = infer_imagenet_subset(model)
        if imagenet_subset is not None:
            dataset_info = ImageNetInfo(imagenet_subset)
            if args.label_type == 'name':
                to_label = lambda x: dataset_info.index_to_label_name(x)
            elif args.label_type in ('detail', 'detailed'):
                to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
            else:
                to_label = lambda x: dataset_info.index_to_description(x)
            to_label = np.vectorize(to_label)
        else:
            _logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")

    top_k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    all_indices = []
    all_labels = []
    all_outputs = []
    use_probs = args.output_type == 'prob'
    with torch.no_grad():
        for batch_idx, (input, _) in enumerate(loader):

            with amp_autocast():
                output = model(input)

            if use_probs:
                output = output.softmax(-1)

            if top_k:
                output, indices = output.topk(top_k)
                np_indices = indices.cpu().numpy()
                if args.include_index:
                    all_indices.append(np_indices)
                if to_label is not None:
                    np_labels = to_label(np_indices)
                    all_labels.append(np_labels)

            all_outputs.append(output.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time))

    all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
    all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
    all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
    filenames = loader.dataset.filenames(basename=not args.fullname)

    output_col = args.output_col or ('prob' if use_probs else 'logit')
    data_dict = {args.filename_col: filenames}
    if args.results_separate_col and all_outputs.shape[-1] > 1:
        # one column per top-k position
        if all_indices is not None:
            for i in range(all_indices.shape[-1]):
                data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
        if all_labels is not None:
            for i in range(all_labels.shape[-1]):
                data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
        for i in range(all_outputs.shape[-1]):
            data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
    else:
        # one column holding a list per sample (scalar when k == 1)
        if all_indices is not None:
            if all_indices.shape[-1] == 1:
                all_indices = all_indices.squeeze(-1)
            data_dict[args.index_col] = list(all_indices)
        if all_labels is not None:
            if all_labels.shape[-1] == 1:
                all_labels = all_labels.squeeze(-1)
            data_dict[args.label_col] = list(all_labels)
        if all_outputs.shape[-1] == 1:
            all_outputs = all_outputs.squeeze(-1)
        data_dict[output_col] = list(all_outputs)

    df = pd.DataFrame(data=data_dict)

    results_filename = args.results_file
    if results_filename:
        filename_no_ext, ext = os.path.splitext(results_filename)
        if ext and ext in _FMT_EXT.values():
            # if filename provided with one of expected ext,
            # remove it as it will be added back
            results_filename = filename_no_ext
    else:
        # base default filename on model name + img-size
        img_size = data_config["input_size"][1]
        results_filename = f'{args.model}-{img_size}'

    if args.results_dir:
        results_filename = os.path.join(args.results_dir, results_filename)

    for fmt in args.results_format:
        # FIX: forward --filename-col so index-oriented formats key on the
        # configured column (original always used save_results' default).
        save_results(df, results_filename, fmt, args.filename_col)

    # FIX: plain string -- original used an f-string with no placeholders.
    print('--result')
    print(df.set_index(args.filename_col).to_json(orient='index', indent=4))
349

350

351
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
    """Write results `df` to `results_filename` in the requested format.

    The extension for `results_format` (from `_FMT_EXT`) is appended to the
    filename automatically. Index-oriented formats ('parquet', 'json') set
    `filename_col` as the index; record/split/csv formats keep it as a column.
    An unrecognized `results_format` raises KeyError on the extension lookup.
    """
    results_filename += _FMT_EXT[results_format]
    if results_format == 'parquet':
        df.set_index(filename_col).to_parquet(results_filename)
    elif results_format == 'json':
        df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
    elif results_format == 'json-record':
        # FIX: was 'json-records', which is not a key of _FMT_EXT and can never
        # be reached -- the CLI value 'json-record' silently fell through to
        # the CSV writer while getting a '.json' extension.
        df.to_json(results_filename, lines=True, orient='records')
    elif results_format == 'json-split':
        df.to_json(results_filename, indent=4, orient='split', index=False)
    else:
        df.to_csv(results_filename, index=False)
363

364

365
# Script entry point guard.
if __name__ == '__main__':
    main()
367

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.