5
from subprocess import PIPE, Popen
7
import caffe2.python._import_c_extension as C
8
from caffe2.proto import caffe2_pb2
9
from caffe2.python import core
13
# NNModule-style constructor (a method of a class whose header is outside this
# view). Builds ``self._NNModule`` through the C extension ``C``:
#   * a ``core.Net`` or ``caffe2_pb2.NetDef`` passed as ``net`` is serialized
#     with ``SerializeToString`` and handed to ``C.NNModuleFromProtobuf``, or to
#     ``C.NNModuleFromProtobufDistributed`` when ``device_map`` is given;
#   * ``C.NNModule()`` is called with no arguments on some other path (the
#     condition selecting it is not visible here).
# NOTE(review): this chunk is a mangled listing. Indentation is lost, the bare
# integers are residue of the original line numbers, and original lines 14,
# 20-21, 24, 28-29, 32-33 and 35-36 are absent. The comments below describe
# only what the visible lines show.
def __init__(self, net=None, device_map=None):
15
serialized_proto = None
16
# Serialize the supported proto containers. No visible branch handles other
# ``net`` types at this point.
if isinstance(net, core.Net):
17
serialized_proto = net.Proto().SerializeToString()
18
elif isinstance(net, caffe2_pb2.NetDef):
19
serialized_proto = net.SerializeToString()
22
# Distributed construction: every ``device_map`` value is serialized too.
# NOTE(review): the loop header that binds ``k`` (original line 24) is missing
# from this view. Presumably it iterates over ``device_map``'s keys; confirm.
if device_map is not None:
23
serialized_device_map = {}
25
serialized_device_map[k] = device_map[k].SerializeToString()
26
self._NNModule = C.NNModuleFromProtobufDistributed(
27
serialized_proto, serialized_device_map
30
# Single-device construction. The extension returns a pair, which also sets
# ``self._OpList``; the distributed path does not assign ``_OpList`` in the
# visible lines.
elif serialized_proto:
31
self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)
34
# NOTE(review): this is an error message whose ``raise`` statement (original
# lines 32-33) is not visible. Presumably it is raised when ``net`` is neither
# a ``core.Net`` nor a ``caffe2_pb2.NetDef``; confirm.
"NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
37
# NOTE(review): this is presumably the fallback for when no ``net`` is given,
# but the guarding ``if``/``else`` lines are not visible. It calls
# ``C.NNModule()`` with no arguments.
self._NNModule = C.NNModule()
41
# NOTE(review): the ``def`` line(s) for this statement (original lines 38-40)
# are not visible. Presumably this is the body of a ``dataFlow`` accessor:
# ``match`` reads ``self.dataFlow.getMutableNodes()`` without calling it,
# which suggests a property; confirm.
# Returns the object produced by ``self._NNModule.dataFlow()``.
return self._NNModule.dataFlow()
44
def controlFlow(self):
    """Return the execution order held by the wrapped extension module.

    Forwards to ``self._NNModule.getExecutionOrder()`` and hands back
    whatever that call produces, unchanged.
    """
    wrapped = self._NNModule
    return wrapped.getExecutionOrder()
49
# NOTE(review): the enclosing ``def`` (original lines 46-48) is not visible.
# Presumably this is a ``nodes`` accessor. It returns the ``nodes`` attribute
# of the graph produced by ``self._NNModule.dataFlow()``.
return self._NNModule.dataFlow().nodes
53
# NOTE(review): the enclosing ``def`` (original lines 50-52) is not visible.
# Presumably this is an ``operators`` accessor. It returns the ``operators``
# attribute of the graph produced by ``self._NNModule.dataFlow()``.
return self._NNModule.dataFlow().operators
57
# NOTE(review): the enclosing ``def`` (original lines 54-56) is not visible.
# Presumably this is a ``tensors`` accessor. It returns the ``tensors``
# attribute of the graph produced by ``self._NNModule.dataFlow()``.
return self._NNModule.dataFlow().tensors
59
def createNode(self, val):
    """Ask the data-flow graph to create a node for *val*.

    Forwards *val* to ``createNode`` on ``self._NNModule.dataFlow()`` and
    returns that call's result unchanged.
    """
    graph = self._NNModule.dataFlow()
    return graph.createNode(val)
62
def deleteNode(self, node):
    """Ask the data-flow graph to delete *node*.

    Forwards *node* to ``deleteNode`` on ``self._NNModule.dataFlow()`` and
    returns that call's result unchanged.
    """
    graph = self._NNModule.dataFlow()
    return graph.deleteNode(node)
65
def createEdge(self, a, b):
    """Ask the data-flow graph to create an edge from *a* to *b*.

    Forwards both endpoints, in order, to ``createEdge`` on
    ``self._NNModule.dataFlow()`` and returns that call's result unchanged.
    """
    graph = self._NNModule.dataFlow()
    return graph.createEdge(a, b)
68
# Delete an edge from the graph returned by ``self._NNModule.dataFlow()``.
# The visible lines call that graph's ``deleteEdge`` either with ``(a, b)`` or
# with ``a`` alone. No ``return`` is visible, and both calls' results are
# discarded.
# NOTE(review): the branch that chooses between the two calls (original lines
# 69 and 71) is missing from this view. Presumably it tests whether ``b`` was
# supplied; confirm before relying on either path.
def deleteEdge(self, a, b=None):
70
self._NNModule.dataFlow().deleteEdge(a, b)
72
self._NNModule.dataFlow().deleteEdge(a)
74
def replaceNode(self, old_node, new_node):
    """Ask the data-flow graph to replace *old_node* with *new_node*.

    Forwards both arguments to ``replaceNode`` on ``self._NNModule.dataFlow()``
    and returns that call's result unchanged.
    """
    graph = self._NNModule.dataFlow()
    return graph.replaceNode(old_node, new_node)
77
def replaceProducer(self, tensor, new_producer):
    """Forward ``(tensor, new_producer)`` to the extension's ``replaceProducer``.

    ``self`` is not consulted. The extension call's result is discarded, so
    this method returns None once that call completes.
    """
    replace_fn = C.replaceProducer
    replace_fn(tensor, new_producer)
    return None
80
def replaceAllUsesWith(self, old_tensor, new_tensor):
    """Forward ``(old_tensor, new_tensor)`` to the extension's ``replaceAllUsesWith``.

    ``self`` is not consulted. The extension call's result is discarded, so
    this method returns None once that call completes.
    """
    replace_fn = C.replaceAllUsesWith
    replace_fn(old_tensor, new_tensor)
    return None
83
def replaceAsConsumer(self, old_consumer, new_consumer):
    """Forward ``(old_consumer, new_consumer)`` to the extension's ``replaceAsConsumer``.

    ``self`` is not consulted. The extension call's result is discarded, so
    this method returns None once that call completes.
    """
    replace_fn = C.replaceAsConsumer
    replace_fn(old_consumer, new_consumer)
    return None
86
def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
    """Delegate to the wrapped module's ``replaceSubgraph``.

    The four arguments are forwarded positionally to
    ``self._NNModule.replaceSubgraph``. Its result is discarded, and this
    method returns None.
    """
    wrapped = self._NNModule
    wrapped.replaceSubgraph(subgraph, new_node, inputs, outputs)
    return None
89
def deleteSubgraph(self, subgraph):
    """Delegate to the wrapped module's ``deleteSubgraph``.

    *subgraph* is forwarded to ``self._NNModule.deleteSubgraph``. Its result
    is discarded, and this method returns None.
    """
    wrapped = self._NNModule
    wrapped.deleteSubgraph(subgraph)
    return None
92
def createUniqueDataNode(self, prefix="_unique"):
    """Delegate to the wrapped module's ``createUniqueDataNode``.

    Forwards *prefix* (default ``"_unique"``) to
    ``self._NNModule.createUniqueDataNode`` and returns its result unchanged.
    """
    wrapped = self._NNModule
    return wrapped.createUniqueDataNode(prefix)
95
# Convert the wrapped module into a ``caffe2_pb2.NetDef``.
# ``old_proto`` is passed to ``self._NNModule.convertToCaffe2Proto``, and that
# call's result is parsed with ``ParseFromString`` into a freshly created
# ``NetDef`` (``new_proto``).
# NOTE(review): original line 96 is missing. As listed, ``old_proto`` is
# overwritten unconditionally, which would discard the caller's argument.
# Presumably the missing line is a guard such as ``if old_proto is None:``;
# confirm.
# NOTE(review): no ``return`` is visible. Original line 101 onward is outside
# this view, so it cannot be confirmed here whether ``new_proto`` is returned.
def convertToCaffe2Proto(self, old_proto=None):
97
old_proto = caffe2_pb2.NetDef()
98
output = self._NNModule.convertToCaffe2Proto(old_proto)
99
new_proto = caffe2_pb2.NetDef()
100
new_proto.ParseFromString(output)
103
# Try ``C.matchSubgraph(n, pattern)`` against every node ``n`` from
# ``self.dataFlow.getMutableNodes()``. ``dataFlow`` is read as an attribute
# here, not called.
# NOTE(review): the rest of the body (original lines 106 onward) is not
# visible. What happens to each match result ``m`` (e.g. yielded, collected
# or filtered) cannot be determined from this view.
def match(self, pattern):
104
for n in self.dataFlow.getMutableNodes():
105
m = C.matchSubgraph(n, pattern)
112
# NOTE(review): this is a fragment. The enclosing definition's header (before
# original line 112) and body lines 115, 118, 120 and 122 onward are not
# visible. ``s`` and ``e`` are bound outside this view, and ``os``/``errno``
# are not imported in the visible part of the file.
# ``cmd_exists(x)`` is true when some directory listed in $PATH contains an
# entry named ``x`` that ``os.access(..., os.X_OK)`` reports as executable.
cmd_exists = lambda x: any(
113
os.access(os.path.join(path, x), os.X_OK)
114
for path in os.getenv("PATH", "").split(os.pathsep)
116
# When ``graph-easy`` is found on PATH, start it with a piped stdin and write
# the UTF-8 encoding of ``s`` to that pipe.
if cmd_exists("graph-easy"):
117
p = Popen("graph-easy", stdin=PIPE)
119
p.stdin.write(s.encode("utf-8"))
121
# Presumably part of an exception handler around the write (the ``try`` and
# ``except`` lines are not visible). EPIPE and EINVAL errors on ``e`` are
# singled out; the handling itself is outside this view.
if e.errno == errno.EPIPE or e.errno == errno.EINVAL:
133
# Public names re-exported from the C extension. ``Operator`` and ``Data`` are
# shorter aliases bound to the very same objects as ``NeuralNetOperator`` and
# ``NeuralNetData``.
NeuralNetOperator = Operator = C.NeuralNetOperator
NeuralNetData = Data = C.NeuralNetData
NNSubgraph = C.NNSubgraph
NNMatchGraph = C.NNMatchGraph
Annotation = C.Annotation