10
from collections import defaultdict
11
from caffe2.python import utils
13
logger = logging.getLogger(__name__)
14
logger.setLevel(logging.INFO)
20
'Cannot import pydot, which is required for drawing a network. This '
21
'can usually be installed in python with "pip install pydot". Also, '
22
'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
23
'can usually be installed with "sudo apt-get install graphviz".'
26
'net_drawer will not run correctly. Please install the correct '
31
from caffe2.proto import caffe2_pb2
37
'fontcolor': '#FFFFFF'
39
BLOB_STYLE = {'shape': 'octagon'}
42
def _rectify_operator_and_name(operators_or_net, name):
43
"""Gets the operators and name for the pydot graph."""
44
if isinstance(operators_or_net, caffe2_pb2.NetDef):
45
operators = operators_or_net.op
47
name = operators_or_net.name
48
elif hasattr(operators_or_net, 'Proto'):
49
net = operators_or_net.Proto()
50
if not isinstance(net, caffe2_pb2.NetDef):
52
"Expecting NetDef, but got {}".format(type(net)))
57
operators = operators_or_net
60
return operators, name
63
def _escape_label(name):
65
return json.dumps(name)
68
def GetOpNodeProducer(append_output, **kwargs):
69
def ReallyGetOpNode(op, op_id):
71
node_name = '%s/%s (op#%d)' % (op.name, op.type, op_id)
73
node_name = '%s (op#%d)' % (op.type, op_id)
75
for output_name in op.output:
76
node_name += '\n' + output_name
77
return pydot.Node(node_name, **kwargs)
78
return ReallyGetOpNode
81
def GetBlobNodeProducer(**kwargs):
    """Build a factory that creates pydot nodes for blobs.

    Every node produced by the returned callable receives the same extra
    keyword arguments (e.g. a style dict such as ``BLOB_STYLE``), which are
    forwarded unchanged to ``pydot.Node``.

    Args:
        **kwargs: styling/attribute keywords applied to every blob node.

    Returns:
        A callable ``(node_name, label) -> pydot.Node``.
    """
    node_style = kwargs

    def _make_blob_node(node_name, label):
        # The caller supplies the unique node id and the human-readable label
        # separately, so identically named blobs can still be distinct nodes.
        return pydot.Node(node_name, label=label, **node_style)

    return _make_blob_node
90
op_node_producer=None,
91
blob_node_producer=None
93
if op_node_producer is None:
94
op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
95
if blob_node_producer is None:
96
blob_node_producer = GetBlobNodeProducer(**BLOB_STYLE)
97
operators, name = _rectify_operator_and_name(operators_or_net, name)
98
graph = pydot.Dot(name, rankdir=rankdir)
100
pydot_node_counts = defaultdict(int)
101
for op_id, op in enumerate(operators):
102
op_node = op_node_producer(op, op_id)
103
graph.add_node(op_node)
107
for input_name in op.input:
108
if input_name not in pydot_nodes:
109
input_node = blob_node_producer(
111
input_name + str(pydot_node_counts[input_name])),
112
label=_escape_label(input_name),
114
pydot_nodes[input_name] = input_node
116
input_node = pydot_nodes[input_name]
117
graph.add_node(input_node)
118
graph.add_edge(pydot.Edge(input_node, op_node))
119
for output_name in op.output:
120
if output_name in pydot_nodes:
122
pydot_node_counts[output_name] += 1
123
output_node = blob_node_producer(
125
output_name + str(pydot_node_counts[output_name])),
126
label=_escape_label(output_name),
128
pydot_nodes[output_name] = output_node
129
graph.add_node(output_node)
130
graph.add_edge(pydot.Edge(op_node, output_node))
134
def GetPydotGraphMinimal(
138
minimal_dependency=False,
139
op_node_producer=None,
141
"""Different from GetPydotGraph, hide all blob nodes and only show op nodes.
143
If minimal_dependency is set as well, for each op, we will only draw the
144
edges to the minimal necessary ancestors. For example, if op c depends on
145
op a and b, and op b depends on a, then only the edge b->c will be drawn
146
because a->c will be implied.
148
if op_node_producer is None:
149
op_node_producer = GetOpNodeProducer(False, **OP_STYLE)
150
operators, name = _rectify_operator_and_name(operators_or_net, name)
151
graph = pydot.Dot(name, rankdir=rankdir)
155
op_ancestry = defaultdict(set)
156
for op_id, op in enumerate(operators):
157
op_node = op_node_producer(op, op_id)
158
graph.add_node(op_node)
161
blob_parents[input_name] for input_name in op.input
162
if input_name in blob_parents
164
op_ancestry[op_node].update(parents)
166
op_ancestry[op_node].update(op_ancestry[node])
167
if minimal_dependency:
171
[node not in op_ancestry[other_node]
172
for other_node in parents]
174
graph.add_edge(pydot.Edge(node, op_node))
178
graph.add_edge(pydot.Edge(node, op_node))
180
for output_name in op.output:
181
blob_parents[output_name] = op_node
185
def GetOperatorMapForPlan(plan_def):
187
for net_id, net in enumerate(plan_def.network):
188
if net.HasField('name'):
189
operator_map[plan_def.name + "_" + net.name] = net.op
191
operator_map[plan_def.name + "_network_%d" % net_id] = net.op
195
def _draw_nets(nets, g):
197
for i, net in enumerate(nets):
198
nodes.append(pydot.Node(_escape_label(net)))
199
g.add_node(nodes[-1])
201
g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
205
def _draw_steps(steps, g, skip_step_edges=False):
206
kMaxParallelSteps = 3
209
label = [step.name + '\n']
211
label.append('Reporter: {}'.format(step.report_net))
212
if step.should_stop_blob:
213
label.append('Stopper: {}'.format(step.should_stop_blob))
214
if step.concurrent_substeps:
215
label.append('Concurrent')
218
return '\n'.join(label)
220
def substep_edge(start, end):
221
return pydot.Edge(start, end, arrowhead='dot', style='dashed')
224
for i, step in enumerate(steps):
225
parallel = step.concurrent_substeps
227
nodes.append(pydot.Node(_escape_label(get_label()), **OP_STYLE))
228
g.add_node(nodes[-1])
230
if i > 0 and not skip_step_edges:
231
g.add_edge(pydot.Edge(nodes[-2], nodes[-1]))
234
sub_nodes = _draw_nets(step.network, g)
237
sub_nodes = _draw_steps(
238
step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
240
sub_nodes = _draw_steps(step.substep, g)
242
raise ValueError('invalid step')
246
g.add_edge(substep_edge(nodes[-1], sn))
247
if len(step.substep) > kMaxParallelSteps:
248
ellipsis = pydot.Node('{} more steps'.format(
249
len(step.substep) - kMaxParallelSteps), **OP_STYLE)
251
g.add_edge(substep_edge(nodes[-1], ellipsis))
253
g.add_edge(substep_edge(nodes[-1], sub_nodes[0]))
258
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
259
graph = pydot.Dot(name, rankdir=rankdir)
260
_draw_steps(plan_def.execution_step, graph)
264
def GetGraphInJson(operators_or_net, output_filepath):
265
operators, _ = _rectify_operator_and_name(operators_or_net, None)
266
blob_strid_to_node_id = {}
267
node_name_counts = defaultdict(int)
270
for op_id, op in enumerate(operators):
271
op_label = op.name + '/' + op.type if op.name else op.type
272
op_node_id = len(nodes)
279
for input_name in op.input:
280
strid = _escape_label(
281
input_name + str(node_name_counts[input_name]))
282
if strid not in blob_strid_to_node_id:
288
blob_strid_to_node_id[strid] = len(nodes)
289
nodes.append(input_node)
291
input_node = nodes[blob_strid_to_node_id[strid]]
293
'source': blob_strid_to_node_id[strid],
296
for output_name in op.output:
297
strid = _escape_label(
298
output_name + str(node_name_counts[output_name]))
299
if strid in blob_strid_to_node_id:
301
node_name_counts[output_name] += 1
302
strid = _escape_label(
303
output_name + str(node_name_counts[output_name]))
305
if strid not in blob_strid_to_node_id:
308
'label': output_name,
311
blob_strid_to_node_id[strid] = len(nodes)
312
nodes.append(output_node)
314
'source': op_node_id,
315
'target': blob_strid_to_node_id[strid]
318
with open(output_filepath, 'w') as f:
319
json.dump({'nodes': nodes, 'edges': edges}, f)
325
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
326
b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
327
b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
330
def GetGraphPngSafe(func, *args, **kwargs):
332
Invokes `func` (e.g. GetPydotGraph) with args. If anything fails - returns
333
and empty image instead of throwing Exception
336
graph = func(*args, **kwargs)
337
if not isinstance(graph, pydot.Dot):
338
raise ValueError("func is expected to return pydot.Dot")
339
return graph.create_png()
340
except Exception as e:
341
logger.error("Failed to draw graph: {}".format(e))
342
return _DummyPngImage
346
parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
349
type=str, required=True,
350
help="The input protobuf file."
354
type=str, default="",
355
help="The prefix to be added to the output filename."
358
"--minimal", action="store_true",
359
help="If set, produce a minimal visualization."
362
"--minimal_dependency", action="store_true",
363
help="If set, only draw minimal dependency."
366
"--append_output", action="store_true",
367
help="If set, append the output blobs to the operator names.")
369
"--rankdir", type=str, default="LR",
370
help="The rank direction of the pydot graph."
372
args = parser.parse_args()
373
with open(args.input, 'r') as fid:
375
graphs = utils.GetContentFromProtoString(
377
caffe2_pb2.PlanDef: GetOperatorMapForPlan,
378
caffe2_pb2.NetDef: lambda x: {x.name: x.op},
381
for key, operators in graphs.items():
383
graph = GetPydotGraphMinimal(
386
rankdir=args.rankdir,
387
node_producer=GetOpNodeProducer(args.append_output, **OP_STYLE),
388
minimal_dependency=args.minimal_dependency)
390
graph = GetPydotGraph(
393
rankdir=args.rankdir,
394
node_producer=GetOpNodeProducer(args.append_output, **OP_STYLE))
395
filename = args.output_prefix + graph.get_name() + '.dot'
396
graph.write(filename, format='raw')
397
pdf_filename = filename[:-3] + 'pdf'
399
graph.write_pdf(pdf_filename)
402
'Error when writing out the pdf file. Pydot requires graphviz '
403
'to convert dot files to pdf, and you may not have installed '
404
'graphviz. On ubuntu this can usually be installed with "sudo '
405
'apt-get install graphviz". We have generated the .dot file '
406
'but will not be able to generate pdf file for now.'
410
if __name__ == '__main__':