pytorch

Форк
0
/
net_builder_test.py 
327 строк · 11.1 Кб
1

2

3

4

5

6
from caffe2.python import workspace
7
from caffe2.python.core import Plan, to_execution_step, Net
8
from caffe2.python.task import Task, TaskGroup, final_output
9
from caffe2.python.net_builder import ops, NetBuilder
10
from caffe2.python.session import LocalSession
11
import unittest
12
import threading
13

14

15
class PythonOpStats:
    """Shared counters for the Python op used by the multi-instance test.

    ``num_instances`` is bumped each time ``python_op_builder`` runs and
    ``num_calls`` each time the op it returns is invoked; both increments
    happen while holding ``lock``.
    """

    num_calls = 0
    num_instances = 0
    lock = threading.Lock()
19

20

21
def python_op_builder():
    """Create the body of a Python op and record that one was created.

    Every call increments ``PythonOpStats.num_instances``. The returned
    ``my_op`` increments ``PythonOpStats.num_calls`` each time it runs.
    Both updates hold ``PythonOpStats.lock`` because
    ``test_multi_instance_python_op`` creates and runs many op instances
    concurrently.

    Returns:
        A callable taking ``(inputs, outputs)``; it ignores both arguments.
    """
    # ``with`` releases the lock even if the body raises, unlike a bare
    # acquire()/release() pair.
    with PythonOpStats.lock:
        PythonOpStats.num_instances += 1

    def my_op(inputs, outputs):
        with PythonOpStats.lock:
            PythonOpStats.num_calls += 1

    return my_op
32

33

34
def _test_loop():
    """Build an ``ops.loop()`` (no count given) that is exited via ``stop_if``.

    A countdown blob starts at 5 and is decremented on each pass. The loop
    stops as soon as the countdown equals 0. ``TestNetBuilder.test_ops``
    expects the returned counter to end at 5.

    Returns:
        The blob counting completed passes.
    """
    remaining = ops.Const(5)
    steps = ops.Const(0)
    with ops.loop():
        # The stop check precedes the updates, so the pass that sees 0
        # does not increment ``steps``.
        zero = ops.Const(0)
        ops.stop_if(ops.EQ([remaining, zero]))
        minus_one = ops.Const(-1)
        ops.Add([remaining, minus_one], [remaining])
        one = ops.Const(1)
        ops.Add([steps, one], [steps])
    return steps
42

43

44
def _test_inner_stop(x):
    """Emit an ``ops.stop_if`` that triggers when ``x`` is less than 5."""
    five = ops.Const(5)
    is_below_five = ops.LT([x, five])
    ops.stop_if(is_below_five)
46

47

48
def _test_outer():
    """Exercise ``ops.stop_guard`` in four situations.

    Returns:
        A 4-tuple of ``has_stopped()`` results, in this order:

        1. the guard whose inner ``stop_if`` condition is false;
        2. the guard whose condition is true;
        3. the guard around an empty body;
        4. the guard around a body that never stops.
    """
    ten = ops.Const(10)
    # stop_if(False): 10 is not below 5.
    with ops.stop_guard() as not_triggered:
        _test_inner_stop(ten)

    # stop_if(True): 3 is below 5.
    three = ops.Const(3)
    with ops.stop_guard() as triggered:
        _test_inner_stop(three)

    # Ordinary op, no stop at all.
    with ops.stop_guard() as no_stop:
        ops.Const(0)

    # Completely empty clause.
    with ops.stop_guard() as empty:
        pass

    return (
        not_triggered.has_stopped(),
        triggered.has_stopped(),
        empty.has_stopped(),
        no_stop.has_stopped(),
    )
69

70

71
def _test_if(x):
    """Build two ``ops.If`` branches that compare ``x`` with 50.

    The result starts at 1. It is set to 2 when ``x > 50`` and to 3 when
    ``x < 50``. In the second branch, ``ops.stop()`` comes before a write of
    4, and the caller (``test_ops``) expects that write never to take effect.

    Returns:
        The result blob.
    """
    result = ops.Const(1)
    above = ops.GT([x, ops.Const(50)])
    with ops.If(above):
        ops.Const(2, blob_out=result)
    below = ops.LT([x, ops.Const(50)])
    with ops.If(below):
        ops.Const(3, blob_out=result)
        ops.stop()
        # Must not take effect: ops.stop() halts execution before it.
        ops.Const(4, blob_out=result)
    return result
80

81

82
class TestNetBuilder(unittest.TestCase):
    """End-to-end tests for nets and tasks built with ``NetBuilder``/``ops``."""

    def test_ops(self):
        """Loop, stop-guard and If helpers produce the expected blob values."""
        with NetBuilder() as nb:
            y = _test_loop()
            z, w, a, b = _test_outer()
            p = _test_if(ops.Const(75))
            q = _test_if(ops.Const(25))
        plan = Plan('name')
        plan.AddStep(to_execution_step(nb))
        ws = workspace.C.Workspace()
        ws.run(plan)
        expected_results = [
            (y, 5),
            (z, False),
            (w, True),
            (a, False),
            (b, False),
            (p, 2),
            (q, 3),
        ]
        # ``blob`` rather than ``b`` so the loop does not rebind the
        # stop-guard result unpacked above.
        for blob, expected in expected_results:
            actual = ws.blobs[str(blob)].fetch()
            self.assertEqual(actual, expected)

    def _expected_loop(self):
        """Compute in pure Python the sums that ``_actual_loop`` builds.

        For every ``loop_iter`` in ``range(10)`` and ``inner_iter`` in
        ``range(loop_iter)``, ``val = loop_iter * 10 + inner_iter`` is added
        to ``total`` and to exactly one bucket: large (``val >= 80``),
        small (``50 <= val < 80``) or tiny (``val < 50``).

        Returns:
            ``(total, total_large, total_small, total_tiny)``.
        """
        total = 0
        total_large = 0
        total_small = 0
        total_tiny = 0
        for loop_iter in range(10):
            outer = loop_iter * 10
            for inner_iter in range(loop_iter):
                val = outer + inner_iter
                if val >= 80:
                    total_large += val
                elif val >= 50:
                    total_small += val
                else:
                    total_tiny += val
                total += val
        return total, total_large, total_small, total_tiny

    def _actual_loop(self):
        """Build, with nested ``ops.loop`` and If/Elif/Else, the sums that
        ``_expected_loop`` computes in Python.

        Returns:
            A list of ``final_output`` handles for
            ``[total, total_large, total_small, total_tiny]``, in that order.
        """
        total = ops.Const(0)
        total_large = ops.Const(0)
        total_small = ops.Const(0)
        total_tiny = ops.Const(0)
        with ops.loop(10) as loop:
            outer = ops.Mul([loop.iter(), ops.Const(10)])
            with ops.loop(loop.iter()) as inner:
                val = ops.Add([outer, inner.iter()])
                with ops.If(ops.GE([val, ops.Const(80)])) as c:
                    ops.Add([total_large, val], [total_large])
                # ``as c`` rebinds ``c`` so the Else below chains from the
                # Elif (presumably what the If/Elif/Else API requires).
                with c.Elif(ops.GE([val, ops.Const(50)])) as c:
                    ops.Add([total_small, val], [total_small])
                with c.Else():
                    ops.Add([total_tiny, val], [total_tiny])
                # Output passed as a list, like every other in-place Add.
                ops.Add([total, val], [total])
        return [
            final_output(x)
            for x in [total, total_large, total_small, total_tiny]
        ]

    def test_net_multi_use(self):
        """Adding the same ``Net`` to a task twice executes it twice."""
        with Task() as task:
            total = ops.Const(0)
            net = Net('my_net')
            net.Add([total, net.Const(1)], [total])
            ops.net(net)
            ops.net(net)
            result = final_output(total)
        with LocalSession() as session:
            session.run(task)
            self.assertEqual(2, result.fetch())

    def test_loops(self):
        """Nested loops with If/Elif/Else match the pure-Python reference."""
        with Task() as task:
            out_actual = self._actual_loop()
        with LocalSession() as session:
            session.run(task)
            expected = self._expected_loop()
            actual = [o.fetch() for o in out_actual]
            # zip() silently stops at the shorter sequence, so make sure no
            # output is missing before comparing pairwise.
            self.assertEqual(len(expected), len(actual))
            for e, a in zip(expected, actual):
                self.assertEqual(e, a)

    def test_setup(self):
        """Ops in ``task_exit`` observe the state after the main body."""
        with Task() as task:
            with ops.task_init():
                one = ops.Const(1)
            two = ops.Add([one, one])
            with ops.task_init():
                three = ops.Const(3)
            accum = ops.Add([two, three])
            # here, accum should be 5
            with ops.task_exit():
                # here, accum should be 6, since this executes after lines below
                seven_1 = ops.Add([accum, one])
            six = ops.Add([accum, one])
            ops.Add([accum, one], [accum])
            seven_2 = ops.Add([accum, one])
            o6 = final_output(six)
            o7_1 = final_output(seven_1)
            o7_2 = final_output(seven_2)
        with LocalSession() as session:
            session.run(task)
            self.assertEqual(o6.fetch(), 6)
            self.assertEqual(o7_1.fetch(), 7)
            self.assertEqual(o7_2.fetch(), 7)

    def test_multi_instance_python_op(self):
        """
        When task instances are created at runtime, C++ concurrently creates
        multiple instances of operators in C++, and concurrently destroys them
        once the task is finished. This means that the destructor of PythonOp
        will be called concurrently, so the GIL must be acquired. This
        test exercises this condition.
        """
        with Task(num_instances=64) as task:
            with ops.loop(4):
                ops.Python((python_op_builder, [], {}))([], [])
        with LocalSession() as session:
            # Reset before running: the counters are only touched at runtime.
            PythonOpStats.num_instances = 0
            PythonOpStats.num_calls = 0
            session.run(task)
            # One builder call per instance, 4 loop iterations per instance.
            self.assertEqual(PythonOpStats.num_instances, 64)
            self.assertEqual(PythonOpStats.num_calls, 256)

    def test_multi_instance(self):
        """Per-instance counters are not shared between task instances."""
        NUM_INSTANCES = 10
        NUM_ITERS = 15
        with TaskGroup() as tg:
            with Task(num_instances=NUM_INSTANCES):
                with ops.task_init():
                    counter1 = ops.CreateCounter([], ['global_counter'])
                    counter2 = ops.CreateCounter([], ['global_counter2'])
                    counter3 = ops.CreateCounter([], ['global_counter3'])
                # both task_counter and local_counter should be thread local
                with ops.task_instance_init():
                    task_counter = ops.CreateCounter([], ['task_counter'])
                local_counter = ops.CreateCounter([], ['local_counter'])
                with ops.loop(NUM_ITERS):
                    ops.CountUp(counter1)
                    ops.CountUp(task_counter)
                    ops.CountUp(local_counter)
                # gather sum of squares of local counters to make sure that
                # each local counter counted exactly up to NUM_ITERS, and
                # that there was no false sharing of counter instances.
                with ops.task_instance_exit():
                    count2 = ops.RetrieveCount(task_counter)
                    with ops.loop(ops.Mul([count2, count2])):
                        ops.CountUp(counter2)
                # This should have the same effect as the above
                count3 = ops.RetrieveCount(local_counter)
                with ops.loop(ops.Mul([count3, count3])):
                    ops.CountUp(counter3)
                # The code below will only run once
                with ops.task_exit():
                    total1 = final_output(ops.RetrieveCount(counter1))
                    total2 = final_output(ops.RetrieveCount(counter2))
                    total3 = final_output(ops.RetrieveCount(counter3))

        with LocalSession() as session:
            session.run(tg)
            self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
            self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
            self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))

    def test_if_net(self):
        """``IfNet``/``Else`` branches, nesting, and branch-local blobs."""
        with NetBuilder() as nb:
            x0 = ops.Const(0)
            x1 = ops.Const(1)
            x2 = ops.Const(2)
            y0 = ops.Const(0)
            y1 = ops.Const(1)
            y2 = ops.Const(2)

            # basic logic
            first_res = ops.Const(0)
            with ops.IfNet(ops.Const(True)):
                ops.Const(1, blob_out=first_res)
            with ops.Else():
                ops.Const(2, blob_out=first_res)

            second_res = ops.Const(0)
            with ops.IfNet(ops.Const(False)):
                ops.Const(1, blob_out=second_res)
            with ops.Else():
                ops.Const(2, blob_out=second_res)

            # nested and sequential ifs,
            # empty then/else,
            # passing outer blobs into branches,
            # writing into outer blobs, incl. into input blob
            # using local blobs
            with ops.IfNet(ops.LT([x0, x1])):
                local_blob = ops.Const(900)
                ops.Add([ops.Const(100), local_blob], [y0])

                gt = ops.GT([x1, x2])
                with ops.IfNet(gt):
                    # empty then
                    pass
                with ops.Else():
                    ops.Add([y1, local_blob], [local_blob])
                    ops.Add([ops.Const(100), y1], [y1])

                with ops.IfNet(ops.EQ([local_blob, ops.Const(901)])):
                    ops.Const(7, blob_out=y2)
                    ops.Add([y1, y2], [y2])
            with ops.Else():
                # empty else
                pass

        plan = Plan('if_net_test')
        plan.AddStep(to_execution_step(nb))
        ws = workspace.C.Workspace()
        ws.run(plan)

        first_res_value = ws.blobs[str(first_res)].fetch()
        second_res_value = ws.blobs[str(second_res)].fetch()
        y0_value = ws.blobs[str(y0)].fetch()
        y1_value = ws.blobs[str(y1)].fetch()
        y2_value = ws.blobs[str(y2)].fetch()

        self.assertEqual(first_res_value, 1)
        self.assertEqual(second_res_value, 2)
        # y0 = 100 + 900
        self.assertEqual(y0_value, 1000)
        # 1 > 2 is false, so the Else ran: local_blob = 901, y1 = 100 + 1
        self.assertEqual(y1_value, 101)
        # local_blob == 901, so y2 = y1 + 7
        self.assertEqual(y2_value, 108)
        # Blobs created inside a branch must not leak into the workspace.
        self.assertTrue(str(local_blob) not in ws.blobs)

    def test_while_net(self):
        """``WhileNet`` re-evaluates its ``Condition`` block each pass."""
        with NetBuilder() as nb:
            x = ops.Const(0)
            y = ops.Const(0)
            with ops.WhileNet():
                with ops.Condition():
                    ops.Add([x, ops.Const(1)], [x])
                    ops.LT([x, ops.Const(7)])
                ops.Add([x, y], [y])

        plan = Plan('while_net_test')
        plan.AddStep(to_execution_step(nb))
        ws = workspace.C.Workspace()
        ws.run(plan)

        x_value = ws.blobs[str(x)].fetch()
        y_value = ws.blobs[str(y)].fetch()

        # The body ran for x = 1..6, so y = 1 + 2 + ... + 6.
        self.assertEqual(x_value, 7)
        self.assertEqual(y_value, 21)
333

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.