caffe

Форк
0
/
classify.py 
138 строк · 4.2 Кб
1
#!/usr/bin/env python
2
"""
3
classify.py is an out-of-the-box image classifer callable from the command line.
4

5
By default it configures and runs the Caffe reference ImageNet model.
6
"""
7
import numpy as np
8
import os
9
import sys
10
import argparse
11
import glob
12
import time
13

14
import caffe
15

16

17
def main(argv):
18
    pycaffe_dir = os.path.dirname(__file__)
19

20
    parser = argparse.ArgumentParser()
21
    # Required arguments: input and output files.
22
    parser.add_argument(
23
        "input_file",
24
        help="Input image, directory, or npy."
25
    )
26
    parser.add_argument(
27
        "output_file",
28
        help="Output npy filename."
29
    )
30
    # Optional arguments.
31
    parser.add_argument(
32
        "--model_def",
33
        default=os.path.join(pycaffe_dir,
34
                "../models/bvlc_reference_caffenet/deploy.prototxt"),
35
        help="Model definition file."
36
    )
37
    parser.add_argument(
38
        "--pretrained_model",
39
        default=os.path.join(pycaffe_dir,
40
                "../models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel"),
41
        help="Trained model weights file."
42
    )
43
    parser.add_argument(
44
        "--gpu",
45
        action='store_true',
46
        help="Switch for gpu computation."
47
    )
48
    parser.add_argument(
49
        "--center_only",
50
        action='store_true',
51
        help="Switch for prediction from center crop alone instead of " +
52
             "averaging predictions across crops (default)."
53
    )
54
    parser.add_argument(
55
        "--images_dim",
56
        default='256,256',
57
        help="Canonical 'height,width' dimensions of input images."
58
    )
59
    parser.add_argument(
60
        "--mean_file",
61
        default=os.path.join(pycaffe_dir,
62
                             'caffe/imagenet/ilsvrc_2012_mean.npy'),
63
        help="Data set image mean of [Channels x Height x Width] dimensions " +
64
             "(numpy array). Set to '' for no mean subtraction."
65
    )
66
    parser.add_argument(
67
        "--input_scale",
68
        type=float,
69
        help="Multiply input features by this scale to finish preprocessing."
70
    )
71
    parser.add_argument(
72
        "--raw_scale",
73
        type=float,
74
        default=255.0,
75
        help="Multiply raw input by this scale before preprocessing."
76
    )
77
    parser.add_argument(
78
        "--channel_swap",
79
        default='2,1,0',
80
        help="Order to permute input channels. The default converts " +
81
             "RGB -> BGR since BGR is the Caffe default by way of OpenCV."
82
    )
83
    parser.add_argument(
84
        "--ext",
85
        default='jpg',
86
        help="Image file extension to take as input when a directory " +
87
             "is given as the input file."
88
    )
89
    args = parser.parse_args()
90

91
    image_dims = [int(s) for s in args.images_dim.split(',')]
92

93
    mean, channel_swap = None, None
94
    if args.mean_file:
95
        mean = np.load(args.mean_file)
96
    if args.channel_swap:
97
        channel_swap = [int(s) for s in args.channel_swap.split(',')]
98

99
    if args.gpu:
100
        caffe.set_mode_gpu()
101
        print("GPU mode")
102
    else:
103
        caffe.set_mode_cpu()
104
        print("CPU mode")
105

106
    # Make classifier.
107
    classifier = caffe.Classifier(args.model_def, args.pretrained_model,
108
            image_dims=image_dims, mean=mean,
109
            input_scale=args.input_scale, raw_scale=args.raw_scale,
110
            channel_swap=channel_swap)
111

112
    # Load numpy array (.npy), directory glob (*.jpg), or image file.
113
    args.input_file = os.path.expanduser(args.input_file)
114
    if args.input_file.endswith('npy'):
115
        print("Loading file: %s" % args.input_file)
116
        inputs = np.load(args.input_file)
117
    elif os.path.isdir(args.input_file):
118
        print("Loading folder: %s" % args.input_file)
119
        inputs =[caffe.io.load_image(im_f)
120
                 for im_f in glob.glob(args.input_file + '/*.' + args.ext)]
121
    else:
122
        print("Loading file: %s" % args.input_file)
123
        inputs = [caffe.io.load_image(args.input_file)]
124

125
    print("Classifying %d inputs." % len(inputs))
126

127
    # Classify.
128
    start = time.time()
129
    predictions = classifier.predict(inputs, not args.center_only)
130
    print("Done in %.2f s." % (time.time() - start))
131

132
    # Save
133
    print("Saving results into %s" % args.output_file)
134
    np.save(args.output_file, predictions)
135

136

137
if __name__ == '__main__':
138
    main(sys.argv)
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.