pytorch

Форк
0
/
net_drawer.py 
411 строк · 13.9 Кб
1
## @package net_drawer
2
# Module caffe2.python.net_drawer
3

4

5

6

7
import argparse
8
import json
9
import logging
10
from collections import defaultdict
11
from caffe2.python import utils
12

13
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# pydot is an optional dependency: this module still imports without it, but
# `pydot` is then bound to None and every helper below that builds pydot
# nodes/edges/graphs will fail when called.
try:
    import pydot
except ImportError:
    pydot = None
    logger.info(
        'Cannot import pydot, which is required for drawing a network. This '
        'can usually be installed in python with "pip install pydot". Also, '
        'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
        'can usually be installed with "sudo apt-get install graphviz".'
    )
    print(
        'net_drawer will not run correctly. Please install the correct '
        'dependencies.'
    )
30

31
from caffe2.proto import caffe2_pb2
32

33
# Default pydot node attributes for operator nodes: a filled green box with
# white text.
OP_STYLE = dict(
    shape='box',
    color='#0F9D58',
    style='filled',
    fontcolor='#FFFFFF',
)
# Default pydot node attributes for blob nodes.
BLOB_STYLE = dict(shape='octagon')
40

41

42
def _rectify_operator_and_name(operators_or_net, name):
    """Gets the operators and name for the pydot graph.

    ``operators_or_net`` may be a ``caffe2_pb2.NetDef``, any object with a
    ``Proto()`` method that returns a NetDef, or anything else, which is
    treated as the operator sequence itself.

    Returns:
        ``(operators, name)``. An explicit ``name`` always wins; otherwise the
        NetDef's ``name`` is used, or ``"unnamed"`` for a bare operator
        sequence.

    Raises:
        RuntimeError: if ``Proto()`` returns something that is not a NetDef.
    """
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        net = operators_or_net
    elif hasattr(operators_or_net, 'Proto'):
        net = operators_or_net.Proto()
        if not isinstance(net, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net)))
    else:
        # A bare operator sequence carries no name of its own.
        return operators_or_net, ("unnamed" if name is None else name)
    return net.op, (net.name if name is None else name)
61

62

63
def _escape_label(name):
64
    # json.dumps is poor man's escaping
65
    return json.dumps(name)
66

67

68
def GetOpNodeProducer(append_output, **kwargs):
    """Return a factory that turns an operator into a pydot node.

    The returned callable takes ``(op, op_id)``. The node name is
    ``"<op.name>/<op.type> (op#<op_id>)"`` for a named op and
    ``"<op.type> (op#<op_id>)"`` otherwise. When ``append_output`` is truthy,
    each of the op's output names is appended on its own line. ``kwargs`` are
    forwarded to ``pydot.Node`` as node attributes (e.g. OP_STYLE).
    """
    def ReallyGetOpNode(op, op_id):
        if not op.name:
            header = '%s (op#%d)' % (op.type, op_id)
        else:
            header = '%s/%s (op#%d)' % (op.name, op.type, op_id)
        lines = [header]
        if append_output:
            lines.extend(op.output)
        return pydot.Node('\n'.join(lines), **kwargs)

    return ReallyGetOpNode
79

80

81
def GetBlobNodeProducer(**kwargs):
    """Return a factory that builds pydot nodes for blobs.

    The returned callable takes ``(node_name, label)`` and creates
    ``pydot.Node(node_name, label=label, ...)``, with ``kwargs`` (e.g.
    BLOB_STYLE) forwarded as extra node attributes.
    """
    style = kwargs

    def ReallyGetBlobNode(node_name, label):
        return pydot.Node(node_name, label=label, **style)

    return ReallyGetBlobNode
85

86
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    op_node_producer=None,
    blob_node_producer=None
):
    """Build a pydot graph with one node per operator and per blob version.

    Args:
        operators_or_net: a NetDef, an object whose ``Proto()`` returns a
            NetDef, or a sequence of operators.
        name: graph name; defaults to the NetDef's name, or "unnamed" for a
            bare operator sequence.
        rankdir: graphviz rank direction.
        op_node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.
        blob_node_producer: callable ``(node_name, label=...) -> pydot.Node``;
            defaults to ``GetBlobNodeProducer(**BLOB_STYLE)``.

    Returns:
        A ``pydot.Dot`` with edges blob -> op for every input and op -> blob
        for every output. Writing to a blob that already has a node creates a
        new node for it (its node name carries a bumped version number, while
        its label stays the blob name). As a result, reads and writes of the
        same blob by one op are drawn as separate nodes.
    """
    make_op_node = (GetOpNodeProducer(False, **OP_STYLE)
                    if op_node_producer is None else op_node_producer)
    make_blob_node = (GetBlobNodeProducer(**BLOB_STYLE)
                      if blob_node_producer is None else blob_node_producer)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # Most recent pydot node for each blob name, and that blob's version.
    current_node = {}
    version = defaultdict(int)

    def fresh_blob_node(blob):
        # The node name includes the version so every write gets its own
        # node; the visible label is just the blob name.
        return make_blob_node(
            _escape_label(blob + str(version[blob])),
            label=_escape_label(blob),
        )

    for op_id, op in enumerate(operators):
        op_node = make_op_node(op, op_id)
        graph.add_node(op_node)
        for blob in op.input:
            if blob not in current_node:
                current_node[blob] = fresh_blob_node(blob)
            source = current_node[blob]
            graph.add_node(source)
            graph.add_edge(pydot.Edge(source, op_node))
        for blob in op.output:
            if blob in current_node:
                # Overwriting an existing blob: start a new version.
                version[blob] += 1
            target = fresh_blob_node(blob)
            current_node[blob] = target
            graph.add_node(target)
            graph.add_edge(pydot.Edge(op_node, target))
    return graph
132

133

134
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    op_node_producer=None,
):
    """Different from GetPydotGraph, hide all blob nodes and only show op nodes.

    An edge ``a -> b`` means op ``b`` reads a blob whose most recent writer
    (at that point in the operator sequence) is op ``a``. Each such parent is
    drawn once, even if ``b`` reads several blobs produced by ``a``.

    If minimal_dependency is set as well, for each op, we will only draw the
    edges to the minimal necessary ancestors. For example, if op c depends on
    op a and b, and op b depends on a, then only the edge b->c will be drawn
    because a->c will be implied.

    Args:
        operators_or_net: a NetDef, an object whose ``Proto()`` returns a
            NetDef, or a sequence of operators.
        name: graph name; defaults to the NetDef's name, or "unnamed".
        rankdir: graphviz rank direction.
        minimal_dependency: drop edges implied by transitive ancestry.
        op_node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        A ``pydot.Dot`` containing only operator nodes.
    """
    if op_node_producer is None:
        op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # blob_parents maps each blob name to the op node that last wrote it.
    blob_parents = {}
    # op_ancestry maps each op node to the set of all its transitive ancestors.
    op_ancestry = defaultdict(set)
    for op_id, op in enumerate(operators):
        op_node = op_node_producer(op, op_id)
        graph.add_node(op_node)
        # Distinct producer ops of this op's inputs, in first-seen order.
        # Without deduplication, an op reading two blobs from the same parent
        # got two parallel edges, which graphviz draws as separate arrows.
        parents = list(dict.fromkeys(
            blob_parents[input_name] for input_name in op.input
            if input_name in blob_parents
        ))
        ancestry = op_ancestry[op_node]
        ancestry.update(parents)
        for parent in parents:
            ancestry.update(op_ancestry[parent])
        for parent in parents:
            # In minimal mode, a parent that is already an ancestor of another
            # parent is reachable through that parent, so its edge is implied.
            if minimal_dependency and any(
                parent in op_ancestry[other] for other in parents
            ):
                continue
            graph.add_edge(pydot.Edge(parent, op_node))
        # Update blob_parents to reflect that this op created the blobs.
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph
183

184

185
def GetOperatorMapForPlan(plan_def):
    """Map every network of a plan to its operator list.

    Keys are ``"<plan name>_<net name>"`` when the net has its ``name`` field
    set, otherwise ``"<plan name>_network_<index>"`` using the net's position
    in ``plan_def.network``. If two keys coincide, the later network wins.
    """
    return {
        (plan_def.name + "_" + net.name if net.HasField('name')
         else plan_def.name + "_network_%d" % net_id): net.op
        for net_id, net in enumerate(plan_def.network)
    }
193

194

195
def _draw_nets(nets, g):
196
    nodes = []
197
    for i, net in enumerate(nets):
198
        nodes.append(pydot.Node(_escape_label(net)))
199
        g.add_node(nodes[-1])
200
        if i > 0:
201
            g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
202
    return nodes
203

204

205
def _draw_steps(steps, g, skip_step_edges=False):  # noqa
206
    kMaxParallelSteps = 3
207

208
    def get_label():
209
        label = [step.name + '\n']
210
        if step.report_net:
211
            label.append('Reporter: {}'.format(step.report_net))
212
        if step.should_stop_blob:
213
            label.append('Stopper: {}'.format(step.should_stop_blob))
214
        if step.concurrent_substeps:
215
            label.append('Concurrent')
216
        if step.only_once:
217
            label.append('Once')
218
        return '\n'.join(label)
219

220
    def substep_edge(start, end):
221
        return pydot.Edge(start, end, arrowhead='dot', style='dashed')
222

223
    nodes = []
224
    for i, step in enumerate(steps):
225
        parallel = step.concurrent_substeps
226

227
        nodes.append(pydot.Node(_escape_label(get_label()), **OP_STYLE))
228
        g.add_node(nodes[-1])
229

230
        if i > 0 and not skip_step_edges:
231
            g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
232

233
        if step.network:
234
            sub_nodes = _draw_nets(step.network, g)
235
        elif step.substep:
236
            if parallel:
237
                sub_nodes = _draw_steps(
238
                    step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
239
            else:
240
                sub_nodes = _draw_steps(step.substep, g)
241
        else:
242
            raise ValueError('invalid step')
243

244
        if parallel:
245
            for sn in sub_nodes:
246
                g.add_edge(substep_edge(nodes[-1], sn))
247
            if len(step.substep) > kMaxParallelSteps:
248
                ellipsis = pydot.Node('{} more steps'.format(
249
                    len(step.substep) - kMaxParallelSteps), **OP_STYLE)
250
                g.add_node(ellipsis)
251
                g.add_edge(substep_edge(nodes[-1], ellipsis))
252
        else:
253
            g.add_edge(substep_edge(nodes[-1], sub_nodes[0]))
254

255
    return nodes
256

257

258
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Draw the execution steps of a PlanDef as a pydot graph.

    Args:
        plan_def: the plan whose ``execution_step`` list is drawn.
        name: graph name. Defaults to ``plan_def.name`` (mirroring how the
            net helpers default to the NetDef's name) instead of handing
            ``None`` to pydot as the graph name.
        rankdir: graphviz rank direction; top-to-bottom by default.

    Returns:
        A ``pydot.Dot`` containing the step/net nodes added by _draw_steps.
    """
    if name is None:
        name = plan_def.name
    graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, graph)
    return graph
262

263

264
def GetGraphInJson(operators_or_net, output_filepath):
    """Write the operator/blob dataflow graph as JSON to ``output_filepath``.

    The file contains ``{"nodes": [...], "edges": [...]}``:

    - each op gets a node ``{"id", "label", "op_id", "type": "op"}``, labelled
      ``"<name>/<type>"`` (or just ``"<type>"`` for an unnamed op);
    - each version of a blob gets a node ``{"id", "label", "type": "blob"}``.
      Writing to a blob that already has a node starts a new version, the same
      way GetPydotGraph does;
    - edges are ``{"source", "target"}`` pairs of node ids: blob -> op for
      inputs and op -> blob for outputs.

    Args:
        operators_or_net: a NetDef, an object whose ``Proto()`` returns a
            NetDef, or a sequence of operators.
        output_filepath: path of the JSON file to create or overwrite.
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)
    nodes = []
    edges = []
    # Current version number of each blob name.
    blob_versions = defaultdict(int)
    # (blob name, version) -> index of that blob version's node in `nodes`.
    # A tuple key cannot collide the way a concatenated string key can
    # (blob "x1" version 0 and blob "x" version 10 would both be "x10").
    blob_node_ids = {}

    def _current_key(blob_name):
        return blob_name, blob_versions[blob_name]

    def _blob_node_id(blob_name):
        # Node id of the current version of `blob_name`, created on first use.
        key = _current_key(blob_name)
        if key not in blob_node_ids:
            blob_node_ids[key] = len(nodes)
            nodes.append({
                'id': len(nodes),
                'label': blob_name,
                'type': 'blob'
            })
        return blob_node_ids[key]

    for op_id, op in enumerate(operators):
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op.name + '/' + op.type if op.name else op.type,
            'op_id': op_id,
            'type': 'op'
        })
        for input_name in op.input:
            edges.append({
                'source': _blob_node_id(input_name),
                'target': op_node_id
            })
        for output_name in op.output:
            if _current_key(output_name) in blob_node_ids:
                # We are overwriting an existing blob: start a new version.
                blob_versions[output_name] += 1
            edges.append({
                'source': op_node_id,
                'target': _blob_node_id(output_name)
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
320

321

322
# A dummy minimal PNG image used by GetGraphPngSafe as a
323
# placeholder when rendering fail to run.
324
_DummyPngImage = (
325
    b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
326
    b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
327
    b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
328

329

330
def GetGraphPngSafe(func, *args, **kwargs):
    """Render ``func(*args, **kwargs)`` to PNG bytes without raising.

    Invokes ``func`` (e.g. GetPydotGraph) with the given arguments and returns
    ``graph.create_png()``. If anything fails, the failure is logged (with its
    traceback) and ``_DummyPngImage`` is returned instead of an exception
    being thrown. Failures include:

    - ``func`` raising;
    - ``func`` returning something that is not a ``pydot.Dot``;
    - pydot being unavailable;
    - ``create_png`` raising.
    """
    try:
        graph = func(*args, **kwargs)
        if not isinstance(graph, pydot.Dot):
            raise ValueError("func is expected to return pydot.Dot")
        return graph.create_png()
    except Exception as e:
        # Best-effort by design: callers always get valid PNG bytes. Keep the
        # traceback in the log so the failure stays debuggable.
        logger.exception("Failed to draw graph: %s", e)
        return _DummyPngImage
343

344

345
def main():
    """Command-line entry point: draw every net found in a protobuf file.

    ``--input`` is parsed by ``utils.GetContentFromProtoString`` with a
    PlanDef handler (one graph per network of the plan) and a NetDef handler
    (one graph).

    For each graph this writes ``<output_prefix><graph name>.dot`` and then
    tries to render the matching ``.pdf``. PDF rendering needs graphviz, so a
    failure there only prints a hint and the .dot file is kept.
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input",
        type=str, required=True,
        help="The input protobuf file."
    )
    parser.add_argument(
        "--output_prefix",
        type=str, default="",
        help="The prefix to be added to the output filename."
    )
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization."
    )
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency."
    )
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph."
    )
    args = parser.parse_args()
    with open(args.input, 'r') as fid:
        content = fid.read()
        graphs = utils.GetContentFromProtoString(
            content, {
                caffe2_pb2.PlanDef: GetOperatorMapForPlan,
                caffe2_pb2.NetDef: lambda x: {x.name: x.op},
            }
        )
    for key, operators in graphs.items():
        # Both graph builders take this as `op_node_producer`; neither accepts
        # a `node_producer` keyword (passing that raises TypeError).
        op_node_producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                op_node_producer=op_node_producer)
        filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(filename, format='raw')
        # Same path with the 'dot' extension swapped for 'pdf'.
        pdf_filename = filename[:-3] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )


if __name__ == '__main__':
    main()
412

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.