pytorch-image-models
/
inference.py
366 lines · 15.6 KB
1#!/usr/bin/env python3
2"""PyTorch Inference Script
3
4An example inference script that outputs top-k class ids for images in a folder into a csv.
5
6Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
7"""
8import argparse
9import json
10import logging
11import os
12import time
13from contextlib import suppress
14from functools import partial
15
16import numpy as np
17import pandas as pd
18import torch
19
20from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
21from timm.layers import apply_test_time_pool
22from timm.models import create_model
23from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
24
# ---------------------------------------------------------------------------
# Optional-dependency probes. These flags record which acceleration backends
# are importable in this environment; main() asserts against them before
# honoring the corresponding CLI switches.
# ---------------------------------------------------------------------------
try:
    from apex import amp  # legacy NVIDIA Apex mixed precision
    has_apex = True
except ImportError:
    has_apex = False

# Native AMP: probe for torch.cuda.amp.autocast (present in PyTorch >= 1.6).
# Older torch builds may lack the torch.cuda.amp module entirely, which
# raises AttributeError on attribute access — hence the try/except.
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

try:
    from functorch.compile import memory_efficient_fusion  # AOT Autograd fusion
    has_functorch = True
except ImportError:
    has_functorch = False

# torch.compile() availability (PyTorch >= 2.0)
has_compile = hasattr(torch, 'compile')
46
47_FMT_EXT = {
48'json': '.json',
49'json-record': '.json',
50'json-split': '.json',
51'parquet': '.parquet',
52'csv': '.csv',
53}
54
55torch.backends.cudnn.benchmark = True
56_logger = logging.getLogger('inference')
57
58
# ---------------------------------------------------------------------------
# CLI definition. Dataset/input options feed resolve_data_config(); the
# results-* options control how the output DataFrame is written.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')

# dataset / input options
parser.add_argument('data', nargs='?', metavar='DIR', const=None,
                    help='path to dataset (*deprecated*, use --data-dir)')
parser.add_argument('--data-dir', metavar='DIR',
                    help='path to dataset (root dir)')
parser.add_argument('--dataset', metavar='NAME', default='',
                    help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='MODEL', default='resnet50',
                    help='model architecture (default: resnet50)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
                    help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
                    metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

# device / execution options
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
                    help='enable test time pool')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
                    help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
                    help='use Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
                    help='lower precision AMP dtype (default: float16)')
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

# model graph optimization backends are mutually exclusive
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None, const='inductor',
                             help="Enable compilation w/ specified backend (default: inductor).")
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support.")

# results output options
parser.add_argument('--results-dir', type=str, default=None,
                    help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
                    help='results filename (relative to results-dir)')
parser.add_argument('--results-format', type=str, nargs='+', default=['csv'],
                    help='results format (one of "csv", "json", "json-record", "json-split", "parquet")')
parser.add_argument('--results-separate-col', action='store_true', default=False,
                    help='separate output columns per result index.')
parser.add_argument('--topk', default=1, type=int,
                    metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
                    help='use full sample name in output (not just basename).')
parser.add_argument('--filename-col', type=str, default='filename',
                    help='name for filename / sample name column')
parser.add_argument('--index-col', type=str, default='index',
                    help='name for output indices column(s)')
parser.add_argument('--label-col', type=str, default='label',
                    help='name for output label column(s)')
parser.add_argument('--output-col', type=str, default=None,
                    help='name for logit/probs output column(s)')
parser.add_argument('--output-type', type=str, default='prob',
                    help='output type column ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--label-type', type=str, default='description',
                    help='type of label to output, one of "none", "name", "description", "detail"')
parser.add_argument('--include-index', action='store_true', default=False,
                    help='include the class index in results')
# NOTE(review): --exclude-output is parsed but not acted on anywhere in this
# script as visible here — confirm intended behavior before relying on it.
parser.add_argument('--exclude-output', action='store_true', default=False,
                    help='exclude logits/probs from results, just indices. topk must be set !=0.')
155
def main():
    """CLI entry point.

    Builds a timm model and data loader from the parsed arguments, runs
    batched inference over the dataset, collects top-k indices / labels /
    outputs, and writes them to results files in each requested format.
    """
    setup_default_logging()
    args = parser.parse_args()
    # might as well try to do something useful...
    args.pretrained = args.pretrained or not args.checkpoint

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(args.device)

    # resolve AMP arguments based on PyTorch / Apex availability
    amp_autocast = suppress
    if args.amp:
        assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
        assert args.amp_dtype in ('float16', 'bfloat16')
        amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
        _logger.info('Running inference in mixed precision with native PyTorch AMP.')
    else:
        _logger.info('Running inference in float32. AMP not enabled.')

    if args.fuser:
        set_jit_fuser(args.fuser)

    # create model; input channel count from --in-chans or --input-size, else 3
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=in_chans,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint,
        **args.model_kwargs,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    _logger.info(
        f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model)
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config)

    model = model.to(device)
    model.eval()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # optional graph optimization (the CLI group makes these mutually exclusive)
    if args.torchscript:
        model = torch.jit.script(model)
    elif args.torchcompile:
        assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
        torch._dynamo.reset()
        model = torch.compile(model, backend=args.torchcompile)
    elif args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    root_dir = args.data or args.data_dir
    dataset = create_dataset(
        root=root_dir,
        name=args.dataset,
        split=args.split,
        class_map=args.class_map,
    )

    if test_time_pool:
        data_config['crop_pct'] = 1.0

    # streaming dataset wrappers (tfds/wds) are restricted to one loader worker
    workers = 1 if 'tfds' in args.dataset or 'wds' in args.dataset else args.workers
    loader = create_loader(
        dataset,
        batch_size=args.batch_size,
        use_prefetcher=True,
        num_workers=workers,
        device=device,
        **data_config,
    )

    # Resolve class-index -> label mapping if a label type was requested.
    # Accept both 'detail' and 'detailed' ('detailed' is what older help text
    # advertised, 'detail' is what the code historically matched).
    to_label = None
    if args.label_type in ('name', 'description', 'detail', 'detailed'):
        imagenet_subset = infer_imagenet_subset(model)
        if imagenet_subset is not None:
            dataset_info = ImageNetInfo(imagenet_subset)
            if args.label_type == 'name':
                to_label = lambda x: dataset_info.index_to_label_name(x)
            elif args.label_type in ('detail', 'detailed'):
                to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
            else:
                to_label = lambda x: dataset_info.index_to_description(x)
            to_label = np.vectorize(to_label)
        else:
            _logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")

    top_k = min(args.topk, args.num_classes)
    batch_time = AverageMeter()
    end = time.time()
    all_indices = []
    all_labels = []
    all_outputs = []
    use_probs = args.output_type == 'prob'
    with torch.no_grad():
        for batch_idx, (inputs, _) in enumerate(loader):

            with amp_autocast():
                output = model(inputs)

            if use_probs:
                output = output.softmax(-1)

            if top_k:
                output, indices = output.topk(top_k)
                np_indices = indices.cpu().numpy()
                if args.include_index:
                    all_indices.append(np_indices)
                if to_label is not None:
                    np_labels = to_label(np_indices)
                    all_labels.append(np_labels)

            all_outputs.append(output.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time))

    all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
    all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
    all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
    filenames = loader.dataset.filenames(basename=not args.fullname)

    # Assemble output DataFrame: either one column per top-k position
    # (--results-separate-col) or a single column holding a list per sample.
    output_col = args.output_col or ('prob' if use_probs else 'logit')
    data_dict = {args.filename_col: filenames}
    if args.results_separate_col and all_outputs.shape[-1] > 1:
        if all_indices is not None:
            for i in range(all_indices.shape[-1]):
                data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
        if all_labels is not None:
            for i in range(all_labels.shape[-1]):
                data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
        for i in range(all_outputs.shape[-1]):
            data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
    else:
        if all_indices is not None:
            if all_indices.shape[-1] == 1:
                all_indices = all_indices.squeeze(-1)
            data_dict[args.index_col] = list(all_indices)
        if all_labels is not None:
            if all_labels.shape[-1] == 1:
                all_labels = all_labels.squeeze(-1)
            data_dict[args.label_col] = list(all_labels)
        if all_outputs.shape[-1] == 1:
            all_outputs = all_outputs.squeeze(-1)
        data_dict[output_col] = list(all_outputs)

    df = pd.DataFrame(data=data_dict)

    results_filename = args.results_file
    if results_filename:
        filename_no_ext, ext = os.path.splitext(results_filename)
        if ext and ext in _FMT_EXT.values():
            # if filename provided with one of expected ext,
            # remove it as it will be added back
            results_filename = filename_no_ext
    else:
        # base default filename on model name + img-size
        img_size = data_config["input_size"][1]
        results_filename = f'{args.model}-{img_size}'

    if args.results_dir:
        results_filename = os.path.join(args.results_dir, results_filename)

    for fmt in args.results_format:
        # Forward the configured filename column so index-oriented formats
        # (json/parquet) don't KeyError when --filename-col != 'filename'.
        save_results(df, results_filename, fmt, args.filename_col)

    print('--result')
    print(df.set_index(args.filename_col).to_json(orient='index', indent=4))
def save_results(df, results_filename, results_format='csv', filename_col='filename'):
    """Write the results DataFrame to disk in the requested format.

    Args:
        df: pandas DataFrame of inference results.
        results_filename: output path WITHOUT extension; the extension
            matching `results_format` is appended here.
        results_format: one of 'csv', 'json', 'json-record'/'json-records',
            'json-split', 'parquet'. Any other known-extension value falls
            back to csv output.
        filename_col: column used as the row index for the index-oriented
            formats ('json' and 'parquet').

    Raises:
        KeyError: if `results_format` is not a recognized format name.
    """
    # Self-contained extension map; accepts both 'json-record' and
    # 'json-records' spellings. (Previously the module-level map only had
    # 'json-record' while the branch below tested 'json-records', so one
    # spelling raised KeyError and the other silently wrote CSV into a
    # .json file.)
    fmt_ext = {
        'json': '.json',
        'json-record': '.json',
        'json-records': '.json',
        'json-split': '.json',
        'parquet': '.parquet',
        'csv': '.csv',
    }
    results_filename += fmt_ext[results_format]
    if results_format == 'parquet':
        df.set_index(filename_col).to_parquet(results_filename)
    elif results_format == 'json':
        df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
    elif results_format in ('json-record', 'json-records'):
        # newline-delimited JSON, one record object per line
        df.to_json(results_filename, lines=True, orient='records')
    elif results_format == 'json-split':
        df.to_json(results_filename, indent=4, orient='split', index=False)
    else:
        df.to_csv(results_filename, index=False)
364
# Standard script entry point: run inference only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
367