caffe

Форк
0
/
summarize.py 
140 строк · 4.8 Кб
1
#!/usr/bin/env python
2

3
"""Net summarization tool.
4

5
This tool summarizes the structure of a net in a concise but comprehensive
6
tabular listing, taking a prototxt file as input.
7

8
Use this tool to check at a glance that the computation you've specified is the
9
computation you expect.
10
"""
11

12
from caffe.proto import caffe_pb2
13
from google import protobuf
14
import re
15
import argparse
16

17
# ANSI codes for coloring blobs (used cyclically)
18
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
19
          '444', '103;30', '107;30']
20
DISCONNECTED_COLOR = '41'
21

22
def read_net(filename):
23
    net = caffe_pb2.NetParameter()
24
    with open(filename) as f:
25
        protobuf.text_format.Parse(f.read(), net)
26
    return net
27

28
def format_param(param):
29
    out = []
30
    if len(param.name) > 0:
31
        out.append(param.name)
32
    if param.lr_mult != 1:
33
        out.append('x{}'.format(param.lr_mult))
34
    if param.decay_mult != 1:
35
        out.append('Dx{}'.format(param.decay_mult))
36
    return ' '.join(out)
37

38
def printed_len(s):
39
    return len(re.sub(r'\033\[[\d;]+m', '', s))
40

41
def print_table(table, max_width):
42
    """Print a simple nicely-aligned table.
43

44
    table must be a list of (equal-length) lists. Columns are space-separated,
45
    and as narrow as possible, but no wider than max_width. Text may overflow
46
    columns; note that unlike string.format, this will not affect subsequent
47
    columns, if possible."""
48

49
    max_widths = [max_width] * len(table[0])
50
    column_widths = [max(printed_len(row[j]) + 1 for row in table)
51
                     for j in range(len(table[0]))]
52
    column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)]
53

54
    for row in table:
55
        row_str = ''
56
        right_col = 0
57
        for cell, width in zip(row, column_widths):
58
            right_col += width
59
            row_str += cell + ' '
60
            row_str += ' ' * max(right_col - printed_len(row_str), 0)
61
        print row_str
62

63
def summarize_net(net):
64
    disconnected_tops = set()
65
    for lr in net.layer:
66
        disconnected_tops |= set(lr.top)
67
        disconnected_tops -= set(lr.bottom)
68

69
    table = []
70
    colors = {}
71
    for lr in net.layer:
72
        tops = []
73
        for ind, top in enumerate(lr.top):
74
            color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
75
            if top in disconnected_tops:
76
                top = '\033[1;4m' + top
77
            if len(lr.loss_weight) > 0:
78
                top = '{} * {}'.format(lr.loss_weight[ind], top)
79
            tops.append('\033[{}m{}\033[0m'.format(color, top))
80
        top_str = ', '.join(tops)
81

82
        bottoms = []
83
        for bottom in lr.bottom:
84
            color = colors.get(bottom, DISCONNECTED_COLOR)
85
            bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
86
        bottom_str = ', '.join(bottoms)
87

88
        if lr.type == 'Python':
89
            type_str = lr.python_param.module + '.' + lr.python_param.layer
90
        else:
91
            type_str = lr.type
92

93
        # Summarize conv/pool parameters.
94
        # TODO support rectangular/ND parameters
95
        conv_param = lr.convolution_param
96
        if (lr.type in ['Convolution', 'Deconvolution']
97
                and len(conv_param.kernel_size) == 1):
98
            arg_str = str(conv_param.kernel_size[0])
99
            if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
100
                arg_str += '/' + str(conv_param.stride[0])
101
            if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
102
                arg_str += '+' + str(conv_param.pad[0])
103
            arg_str += ' ' + str(conv_param.num_output)
104
            if conv_param.group != 1:
105
                arg_str += '/' + str(conv_param.group)
106
        elif lr.type == 'Pooling':
107
            arg_str = str(lr.pooling_param.kernel_size)
108
            if lr.pooling_param.stride != 1:
109
                arg_str += '/' + str(lr.pooling_param.stride)
110
            if lr.pooling_param.pad != 0:
111
                arg_str += '+' + str(lr.pooling_param.pad)
112
        else:
113
            arg_str = ''
114

115
        if len(lr.param) > 0:
116
            param_strs = map(format_param, lr.param)
117
            if max(map(len, param_strs)) > 0:
118
                param_str = '({})'.format(', '.join(param_strs))
119
            else:
120
                param_str = ''
121
        else:
122
            param_str = ''
123

124
        table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
125
                      arg_str])
126
    return table
127

128
def main():
129
    parser = argparse.ArgumentParser(description="Print a concise summary of net computation.")
130
    parser.add_argument('filename', help='net prototxt file to summarize')
131
    parser.add_argument('-w', '--max-width', help='maximum field width',
132
            type=int, default=30)
133
    args = parser.parse_args()
134

135
    net = read_net(args.filename)
136
    table = summarize_net(net)
137
    print_table(table, max_width=args.max_width)
138

139
if __name__ == '__main__':
140
    main()
141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.