caffe

Форк
0
/
detect.py 
173 строки · 5.6 Кб
1
#!/usr/bin/env python
"""
detect.py is an out-of-the-box windowed detector
callable from the command line.

By default it configures and runs the Caffe reference ImageNet model.
Note that this model was trained for image classification and not detection,
and fine-tuning for detection can be expected to improve results.

The selective_search_ijcv_with_python code required for the selective search
proposal mode is available at
    https://github.com/sergeyk/selective_search_ijcv_with_python

TODO:
- batch up image filenames as well: don't want to load all of them into memory
- come up with a batching scheme that preserves order / keeps a unique ID
"""
18
import numpy as np
19
import pandas as pd
20
import os
21
import argparse
22
import time
23

24
import caffe
25

26
# Window-proposal strategies accepted by the --crop_mode option.
CROP_MODES = [
    'list',
    'selective_search',
]
# Window coordinate column names; this order fixes how window arrays are
# laid out when read from csv input and when written to the output.
COORD_COLS = [
    'ymin',
    'xmin',
    'ymax',
    'xmax',
]
28

29

30
def _parse_args(argv=None):
    """Parse the command-line options of the windowed detector.

    Parameters
    ----------
    argv : list of str or None
        Arguments *excluding* the program name; ``None`` lets argparse read
        ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed options. Exits through ``parser.error`` (status 2) when
        ``--crop_mode list`` is combined with a non-csv input file.
    """
    pycaffe_dir = os.path.dirname(__file__)

    parser = argparse.ArgumentParser()
    # Required arguments: input and output.
    parser.add_argument(
        "input_file",
        help="Input txt/csv filename. If .txt, must be list of filenames. "
             "If .csv, must be comma-separated file with header "
             "'filename, xmin, ymin, xmax, ymax'"
    )
    parser.add_argument(
        "output_file",
        help="Output h5/csv filename. Format depends on extension."
    )
    # Optional arguments.
    parser.add_argument(
        "--model_def",
        default=os.path.join(
            pycaffe_dir, "../models/bvlc_reference_caffenet/deploy.prototxt"),
        help="Model definition file."
    )
    parser.add_argument(
        "--pretrained_model",
        default=os.path.join(
            pycaffe_dir,
            "../models/bvlc_reference_caffenet/"
            "bvlc_reference_caffenet.caffemodel"),
        help="Trained model weights file."
    )
    parser.add_argument(
        "--crop_mode",
        default="selective_search",
        choices=CROP_MODES,
        help="How to generate windows for detection."
    )
    parser.add_argument(
        "--gpu",
        action='store_true',
        help="Switch for gpu computation."
    )
    parser.add_argument(
        "--mean_file",
        default=os.path.join(pycaffe_dir,
                             'caffe/imagenet/ilsvrc_2012_mean.npy'),
        # _load_mean averages over axes 1 and 2, i.e. the mean is
        # expected channels-first.
        help="Data set image mean of K x H x W dimensions (numpy array). " +
             "Set to '' for no mean subtraction."
    )
    parser.add_argument(
        "--input_scale",
        type=float,
        help="Multiply input features by this scale to finish preprocessing."
    )
    parser.add_argument(
        "--raw_scale",
        type=float,
        default=255.0,
        help="Multiply raw input by this scale before preprocessing."
    )
    parser.add_argument(
        "--channel_swap",
        default='2,1,0',
        help="Order to permute input channels. The default converts " +
             "RGB -> BGR since BGR is the Caffe default by way of OpenCV."
    )
    parser.add_argument(
        "--context_pad",
        type=int,
        default=16,
        help="Amount of surrounding context to collect in input window."
    )
    args = parser.parse_args(argv)

    # 'list' mode needs per-window coordinates, which only csv input carries;
    # fail here with a usage message instead of deep inside detection.
    if args.crop_mode == 'list' and not args.input_file.lower().endswith('csv'):
        parser.error("--crop_mode list requires a .csv input_file with "
                     "window coordinates.")
    return args


def _load_mean(mean_file):
    """Load the mean array and reduce a spatial mean to per-channel values.

    Returns None when ``mean_file`` is empty (no mean subtraction). A 3-D
    mean whose trailing axes are not (1, 1) is averaged over axes 1 and 2;
    any other shape (e.g. an already-reduced 1-D mean) is returned as is.
    """
    if not mean_file:
        return None
    mean = np.load(mean_file)
    if mean.ndim == 3 and mean.shape[1:] != (1, 1):
        mean = mean.mean(1).mean(1)
    return mean


def _load_inputs(input_file):
    """Read a .txt list of image filenames or a .csv of windows indexed by filename.

    Raises
    ------
    ValueError
        If the file name ends in neither 'txt' nor 'csv'.
    """
    lowered = input_file.lower()
    if lowered.endswith('txt'):
        with open(input_file) as f:
            # Skip blank lines so they are not treated as image filenames.
            names = (line.strip() for line in f)
            return [name for name in names if name]
    if lowered.endswith('csv'):
        inputs = pd.read_csv(input_file, sep=',', dtype={'filename': str})
        inputs.set_index('filename', inplace=True)
        return inputs
    raise ValueError("Unknown input file type: not in txt or csv.")


def _windows_by_image(inputs):
    """Group filename-indexed window rows into a list of (filename, windows) pairs.

    Filenames appear in order of first occurrence and each image's windows
    keep their file order; window columns are in COORD_COLS order.
    """
    # One pass with groupby instead of an index scan per filename.
    return [
        (filename, group[COORD_COLS].values)
        for filename, group in inputs.groupby(level=0, sort=False)
    ]


def _detections_frame(detections):
    """Collect detection dicts into a frame indexed by 'filename'.

    Each detection's 'window' entry is expanded into the COORD_COLS columns
    and the 'window' column is then dropped.
    """
    df = pd.DataFrame(detections)
    df.set_index('filename', inplace=True)
    df[COORD_COLS] = pd.DataFrame(
        data=np.vstack(df['window'].tolist()), index=df.index,
        columns=COORD_COLS)
    del df['window']
    return df


def _save_frame(df, output_file):
    """Write detections as csv (coordinates + per-class scores) or as HDF5."""
    if output_file.lower().endswith('csv'):
        # NOTE(review): upstream caffe.Detector appears to store each window's
        # score vector under 'prediction'; 'feat' (the key this script used
        # originally) is kept as a fallback — confirm against the installed
        # detector.
        score_col = 'prediction' if 'prediction' in df.columns else 'feat'
        # Flatten every score vector so vstack yields one row per window.
        scores = np.vstack([np.ravel(s) for s in df[score_col]])
        # Enumerate the class probabilities; the class count comes from the
        # scores themselves (the former NUM_OUTPUT was never defined).
        class_cols = ['class{}'.format(x) for x in range(scores.shape[1])]
        df[class_cols] = pd.DataFrame(
            data=scores, index=df.index, columns=class_cols)
        # `columns=` replaces the `cols=` keyword pandas no longer accepts.
        df.to_csv(output_file, columns=COORD_COLS + class_cols)
    else:
        # h5
        df.to_hdf(output_file, key='df', mode='w')


def main(argv=None):
    """Run windowed detection from the command line and save the results.

    Parameters
    ----------
    argv : list of str or None
        Full argument vector *including* the program name, as in
        ``sys.argv``; ``None`` parses ``sys.argv`` directly.
    """
    args = _parse_args(argv[1:] if argv is not None else None)

    mean = _load_mean(args.mean_file)
    channel_swap = None
    if args.channel_swap:
        channel_swap = [int(s) for s in args.channel_swap.split(',')]

    if args.gpu:
        caffe.set_mode_gpu()
        print("GPU mode")
    else:
        caffe.set_mode_cpu()
        print("CPU mode")

    # Make detector.
    detector = caffe.Detector(args.model_def, args.pretrained_model,
                              mean=mean,
                              input_scale=args.input_scale,
                              raw_scale=args.raw_scale,
                              channel_swap=channel_swap,
                              context_pad=args.context_pad)

    # Load input.
    t = time.time()
    print("Loading input...")
    inputs = _load_inputs(args.input_file)

    # Detect.
    if args.crop_mode == 'list':
        # _parse_args guarantees csv input here, so `inputs` is a DataFrame.
        detections = detector.detect_windows(_windows_by_image(inputs))
    else:
        # Selective search only needs image paths; iterating a DataFrame
        # would yield its column names, so pass the distinct filenames.
        if isinstance(inputs, pd.DataFrame):
            inputs = inputs.index.unique().tolist()
        detections = detector.detect_selective_search(inputs)
    print("Processed {} windows in {:.3f} s.".format(len(detections),
                                                     time.time() - t))

    # Collect into dataframe with labeled fields.
    df = _detections_frame(detections)

    # Save results.
    t = time.time()
    _save_frame(df, args.output_file)
    print("Saved to {} in {:.3f} s.".format(args.output_file,
                                            time.time() - t))
169

170

171
if __name__ == "__main__":
    # Script entry point: main() receives the full sys.argv, program name
    # included.
    import sys
    main(sys.argv)
174

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.