8
from caffe2.python import workspace, core
9
import caffe2.python.parallel_workers as parallel_workers
15
workspace.RunOperatorOnce(
17
"CreateBlobsQueue", [], [queue], num_blobs=1, capacity=1000
24
workspace.C.Workspace.current.create_blob("blob_" + str(i))
25
workspace.C.Workspace.current.create_blob("status_blob_" + str(i))
26
workspace.C.Workspace.current.create_blob("dequeue_blob")
27
workspace.C.Workspace.current.create_blob("status_blob")
32
def create_worker(queue, get_blob_data):
33
def dummy_worker(worker_id):
34
blob = 'blob_' + str(worker_id)
36
workspace.FeedBlob(blob, get_blob_data(worker_id))
38
workspace.RunOperatorOnce(
40
'SafeEnqueueBlobs', [queue, blob], [blob, 'status_blob_' + str(worker_id)]
47
def dequeue_value(queue):
48
dequeue_blob = 'dequeue_blob'
49
workspace.RunOperatorOnce(
51
"SafeDequeueBlobs", [queue], [dequeue_blob, 'status_blob']
55
return workspace.FetchBlob(dequeue_blob)
58
class ParallelWorkersTest(unittest.TestCase):
59
def testParallelWorkers(self):
60
workspace.ResetWorkspace()
62
queue = create_queue()
63
dummy_worker = create_worker(queue, str)
64
worker_coordinator = parallel_workers.init_workers(dummy_worker)
65
worker_coordinator.start()
68
value = dequeue_value(queue)
70
value in [b'0', b'1'], 'Got unexpected value ' + str(value)
73
self.assertTrue(worker_coordinator.stop())
75
def testParallelWorkersInitFun(self):
76
workspace.ResetWorkspace()
78
queue = create_queue()
79
dummy_worker = create_worker(
80
queue, lambda worker_id: workspace.FetchBlob('data')
82
workspace.FeedBlob('data', 'not initialized')
84
def init_fun(worker_coordinator, global_coordinator):
85
workspace.FeedBlob('data', 'initialized')
87
worker_coordinator = parallel_workers.init_workers(
88
dummy_worker, init_fun=init_fun
90
worker_coordinator.start()
93
value = dequeue_value(queue)
95
value, b'initialized', 'Got unexpected value ' + str(value)
99
worker_coordinator.stop()
101
def testParallelWorkersShutdownFun(self):
102
workspace.ResetWorkspace()
104
queue = create_queue()
105
dummy_worker = create_worker(queue, str)
106
workspace.FeedBlob('data', 'not shutdown')
109
workspace.FeedBlob('data', 'shutdown')
111
worker_coordinator = parallel_workers.init_workers(
112
dummy_worker, shutdown_fun=shutdown_fun
114
worker_coordinator.start()
116
self.assertTrue(worker_coordinator.stop())
118
data = workspace.FetchBlob('data')
119
self.assertEqual(data, b'shutdown', 'Got unexpected value ' + str(data))