pytorch

Форк
0
/
nomnigraph.py 
138 строк · 4.1 Кб
1

2

3
import errno
4
import os
5
from subprocess import PIPE, Popen
6

7
import caffe2.python._import_c_extension as C
8
from caffe2.proto import caffe2_pb2
9
from caffe2.python import core
10

11

12
class NNModule:
13
    def __init__(self, net=None, device_map=None):
14
        if net is not None:
15
            serialized_proto = None
16
            if isinstance(net, core.Net):
17
                serialized_proto = net.Proto().SerializeToString()
18
            elif isinstance(net, caffe2_pb2.NetDef):
19
                serialized_proto = net.SerializeToString()
20

21
            # Distributed
22
            if device_map is not None:
23
                serialized_device_map = {}
24
                for k in device_map:
25
                    serialized_device_map[k] = device_map[k].SerializeToString()
26
                self._NNModule = C.NNModuleFromProtobufDistributed(
27
                    serialized_proto, serialized_device_map
28
                )
29
            # Default
30
            elif serialized_proto:
31
                self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)
32
            else:
33
                raise Exception(
34
                    "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
35
                )
36
        else:
37
            self._NNModule = C.NNModule()
38

39
    @property
40
    def dataFlow(self):
41
        return self._NNModule.dataFlow()
42

43
    @property
44
    def controlFlow(self):
45
        return self._NNModule.getExecutionOrder()
46

47
    @property
48
    def nodes(self):
49
        return self._NNModule.dataFlow().nodes
50

51
    @property
52
    def operators(self):
53
        return self._NNModule.dataFlow().operators
54

55
    @property
56
    def tensors(self):
57
        return self._NNModule.dataFlow().tensors
58

59
    def createNode(self, val):
60
        return self._NNModule.dataFlow().createNode(val)
61

62
    def deleteNode(self, node):
63
        return self._NNModule.dataFlow().deleteNode(node)
64

65
    def createEdge(self, a, b):
66
        return self._NNModule.dataFlow().createEdge(a, b)
67

68
    def deleteEdge(self, a, b=None):
69
        if b:
70
            self._NNModule.dataFlow().deleteEdge(a, b)
71
        else:
72
            self._NNModule.dataFlow().deleteEdge(a)
73

74
    def replaceNode(self, old_node, new_node):
75
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)
76

77
    def replaceProducer(self, tensor, new_producer):
78
        C.replaceProducer(tensor, new_producer)
79

80
    def replaceAllUsesWith(self, old_tensor, new_tensor):
81
        C.replaceAllUsesWith(old_tensor, new_tensor)
82

83
    def replaceAsConsumer(self, old_consumer, new_consumer):
84
        C.replaceAsConsumer(old_consumer, new_consumer)
85

86
    def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
87
        self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)
88

89
    def deleteSubgraph(self, subgraph):
90
        self._NNModule.deleteSubgraph(subgraph)
91

92
    def createUniqueDataNode(self, prefix="_unique"):
93
        return self._NNModule.createUniqueDataNode(prefix)
94

95
    def convertToCaffe2Proto(self, old_proto=None):
96
        if not old_proto:
97
            old_proto = caffe2_pb2.NetDef()
98
        output = self._NNModule.convertToCaffe2Proto(old_proto)
99
        new_proto = caffe2_pb2.NetDef()
100
        new_proto.ParseFromString(output)
101
        return new_proto
102

103
    def match(self, pattern):
104
        for n in self.dataFlow.getMutableNodes():
105
            m = C.matchSubgraph(n, pattern)
106
            if m:
107
                yield m
108

109

110
def render(s):
111
    s = str(s)
112
    cmd_exists = lambda x: any(
113
        os.access(os.path.join(path, x), os.X_OK)
114
        for path in os.getenv("PATH", "").split(os.pathsep)
115
    )
116
    if cmd_exists("graph-easy"):
117
        p = Popen("graph-easy", stdin=PIPE)
118
        try:
119
            p.stdin.write(s.encode("utf-8"))
120
        except IOError as e:
121
            if e.errno == errno.EPIPE or e.errno == errno.EINVAL:
122
                pass
123
            else:
124
                # Raise any other error.
125
                raise
126

127
        p.stdin.close()
128
        p.wait()
129
    else:
130
        print(s)
131

132

133
NeuralNetOperator = C.NeuralNetOperator
134
Operator = C.NeuralNetOperator
135
NeuralNetData = C.NeuralNetData
136
Data = C.NeuralNetData
137
NNSubgraph = C.NNSubgraph
138
NNMatchGraph = C.NNMatchGraph
139
Graph = C.Graph
140
Annotation = C.Annotation
141

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.