6
from caffe2.python import workspace
7
from caffe2.python.core import Plan, to_execution_step, Net
8
from caffe2.python.task import Task, TaskGroup, final_output
9
from caffe2.python.net_builder import ops, NetBuilder
10
from caffe2.python.session import LocalSession
16
lock = threading.Lock()
21
def python_op_builder():
22
PythonOpStats.lock.acquire()
23
PythonOpStats.num_instances += 1
24
PythonOpStats.lock.release()
26
def my_op(inputs, outputs):
27
PythonOpStats.lock.acquire()
28
PythonOpStats.num_calls += 1
29
PythonOpStats.lock.release()
38
ops.stop_if(ops.EQ([x, ops.Const(0)]))
39
ops.Add([x, ops.Const(-1)], [x])
40
ops.Add([y, ops.Const(1)], [y])
44
def _test_inner_stop(x):
45
ops.stop_if(ops.LT([x, ops.Const(5)]))
51
with ops.stop_guard() as g1:
56
with ops.stop_guard() as g2:
60
with ops.stop_guard() as g4:
64
with ops.stop_guard() as g3:
68
g1.has_stopped(), g2.has_stopped(), g3.has_stopped(), g4.has_stopped())
73
with ops.If(ops.GT([x, ops.Const(50)])):
74
ops.Const(2, blob_out=y)
75
with ops.If(ops.LT([x, ops.Const(50)])):
76
ops.Const(3, blob_out=y)
78
ops.Const(4, blob_out=y)
82
class TestNetBuilder(unittest.TestCase):
84
with NetBuilder() as nb:
86
z, w, a, b = _test_outer()
87
p = _test_if(ops.Const(75))
88
q = _test_if(ops.Const(25))
90
plan.AddStep(to_execution_step(nb))
91
ws = workspace.C.Workspace()
102
for b, expected in expected_results:
103
actual = ws.blobs[str(b)].fetch()
104
self.assertEqual(actual, expected)
106
def _expected_loop(self):
111
for loop_iter in range(10):
112
outer = loop_iter * 10
113
for inner_iter in range(loop_iter):
114
val = outer + inner_iter
122
return total, total_large, total_small, total_tiny
124
def _actual_loop(self):
126
total_large = ops.Const(0)
127
total_small = ops.Const(0)
128
total_tiny = ops.Const(0)
129
with ops.loop(10) as loop:
130
outer = ops.Mul([loop.iter(), ops.Const(10)])
131
with ops.loop(loop.iter()) as inner:
132
val = ops.Add([outer, inner.iter()])
133
with ops.If(ops.GE([val, ops.Const(80)])) as c:
134
ops.Add([total_large, val], [total_large])
135
with c.Elif(ops.GE([val, ops.Const(50)])) as c:
136
ops.Add([total_small, val], [total_small])
138
ops.Add([total_tiny, val], [total_tiny])
139
ops.Add([total, val], total)
142
for x in [total, total_large, total_small, total_tiny]
145
def test_net_multi_use(self):
149
net.Add([total, net.Const(1)], [total])
152
result = final_output(total)
153
with LocalSession() as session:
155
self.assertEqual(2, result.fetch())
157
def test_loops(self):
159
out_actual = self._actual_loop()
160
with LocalSession() as session:
162
expected = self._expected_loop()
163
actual = [o.fetch() for o in out_actual]
164
for e, a in zip(expected, actual):
165
self.assertEqual(e, a)
167
def test_setup(self):
169
with ops.task_init():
171
two = ops.Add([one, one])
172
with ops.task_init():
174
accum = ops.Add([two, three])
175
# here, accum should be 5
176
with ops.task_exit():
177
# here, accum should be 6, since this executes after lines below
178
seven_1 = ops.Add([accum, one])
179
six = ops.Add([accum, one])
180
ops.Add([accum, one], [accum])
181
seven_2 = ops.Add([accum, one])
182
o6 = final_output(six)
183
o7_1 = final_output(seven_1)
184
o7_2 = final_output(seven_2)
185
with LocalSession() as session:
187
self.assertEqual(o6.fetch(), 6)
188
self.assertEqual(o7_1.fetch(), 7)
189
self.assertEqual(o7_2.fetch(), 7)
191
def test_multi_instance_python_op(self):
193
When task instances are created at runtime, C++ concurrently creates
194
multiple instances of operators in C++, and concurrently destroys them
195
once the task is finished. This means that the destructor of PythonOp
196
will be called concurrently, so the GIL must be acquired. This
197
test exercises this condition.
199
with Task(num_instances=64) as task:
201
ops.Python((python_op_builder, [], {}))([], [])
202
with LocalSession() as session:
203
PythonOpStats.num_instances = 0
204
PythonOpStats.num_calls = 0
206
self.assertEqual(PythonOpStats.num_instances, 64)
207
self.assertEqual(PythonOpStats.num_calls, 256)
209
def test_multi_instance(self):
212
with TaskGroup() as tg:
213
with Task(num_instances=NUM_INSTANCES):
214
with ops.task_init():
215
counter1 = ops.CreateCounter([], ['global_counter'])
216
counter2 = ops.CreateCounter([], ['global_counter2'])
217
counter3 = ops.CreateCounter([], ['global_counter3'])
218
# both task_counter and local_counter should be thread local
219
with ops.task_instance_init():
220
task_counter = ops.CreateCounter([], ['task_counter'])
221
local_counter = ops.CreateCounter([], ['local_counter'])
222
with ops.loop(NUM_ITERS):
223
ops.CountUp(counter1)
224
ops.CountUp(task_counter)
225
ops.CountUp(local_counter)
226
# gather sum of squares of local counters to make sure that
227
# each local counter counted exactly up to NUM_ITERS, and
228
# that there was no false sharing of counter instances.
229
with ops.task_instance_exit():
230
count2 = ops.RetrieveCount(task_counter)
231
with ops.loop(ops.Mul([count2, count2])):
232
ops.CountUp(counter2)
233
# This should have the same effect as the above
234
count3 = ops.RetrieveCount(local_counter)
235
with ops.loop(ops.Mul([count3, count3])):
236
ops.CountUp(counter3)
237
# The code below will only run once
238
with ops.task_exit():
239
total1 = final_output(ops.RetrieveCount(counter1))
240
total2 = final_output(ops.RetrieveCount(counter2))
241
total3 = final_output(ops.RetrieveCount(counter3))
243
with LocalSession() as session:
245
self.assertEqual(total1.fetch(), NUM_INSTANCES * NUM_ITERS)
246
self.assertEqual(total2.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
247
self.assertEqual(total3.fetch(), NUM_INSTANCES * (NUM_ITERS ** 2))
249
def test_if_net(self):
250
with NetBuilder() as nb:
259
first_res = ops.Const(0)
260
with ops.IfNet(ops.Const(True)):
261
ops.Const(1, blob_out=first_res)
263
ops.Const(2, blob_out=first_res)
265
second_res = ops.Const(0)
266
with ops.IfNet(ops.Const(False)):
267
ops.Const(1, blob_out=second_res)
269
ops.Const(2, blob_out=second_res)
271
# nested and sequential ifs,
273
# passing outer blobs into branches,
274
# writing into outer blobs, incl. into input blob
276
with ops.IfNet(ops.LT([x0, x1])):
277
local_blob = ops.Const(900)
278
ops.Add([ops.Const(100), local_blob], [y0])
280
gt = ops.GT([x1, x2])
285
ops.Add([y1, local_blob], [local_blob])
286
ops.Add([ops.Const(100), y1], [y1])
288
with ops.IfNet(ops.EQ([local_blob, ops.Const(901)])):
289
ops.Const(7, blob_out=y2)
290
ops.Add([y1, y2], [y2])
295
plan = Plan('if_net_test')
296
plan.AddStep(to_execution_step(nb))
297
ws = workspace.C.Workspace()
300
first_res_value = ws.blobs[str(first_res)].fetch()
301
second_res_value = ws.blobs[str(second_res)].fetch()
302
y0_value = ws.blobs[str(y0)].fetch()
303
y1_value = ws.blobs[str(y1)].fetch()
304
y2_value = ws.blobs[str(y2)].fetch()
306
self.assertEqual(first_res_value, 1)
307
self.assertEqual(second_res_value, 2)
308
self.assertEqual(y0_value, 1000)
309
self.assertEqual(y1_value, 101)
310
self.assertEqual(y2_value, 108)
311
self.assertTrue(str(local_blob) not in ws.blobs)
313
def test_while_net(self):
314
with NetBuilder() as nb:
318
with ops.Condition():
319
ops.Add([x, ops.Const(1)], [x])
320
ops.LT([x, ops.Const(7)])
323
plan = Plan('while_net_test')
324
plan.AddStep(to_execution_step(nb))
325
ws = workspace.C.Workspace()
328
x_value = ws.blobs[str(x)].fetch()
329
y_value = ws.blobs[str(y)].fetch()
331
self.assertEqual(x_value, 7)
332
self.assertEqual(y_value, 21)