7
from caffe2.python import core, test_util, workspace
8
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
9
from caffe2.python.model_helper import ModelHelper
13
class TestControl(test_util.TestCase):
    """Tests for gradient helpers of control-flow operators."""

    def test_disambiguate_grad_if_op_output(self):
        """Check output renaming performed by disambiguate_grad_if_op_output.

        Builds a conditional gradient operator whose then/else subnets both
        write ``input_grad``, asks the helper to rename output 0, and then
        verifies that the rename is visible both on the operator itself and
        on the second op of the else-branch subnet.
        """
        workspace.FeedBlob("cond", np.array(True))
        workspace.FeedBlob("then_grad", np.array(1))
        workspace.FeedBlob("else_grad", np.array(2))

        # Then-branch: a single Copy that writes "input_grad".
        then_model = ModelHelper(name="then_test_model")
        then_model.net.Copy("then_grad", "input_grad")

        # Else-branch: two Copy ops; the second one (index 1) is the op that
        # writes "input_grad" and is the one inspected by the assertion below.
        # NOTE(review): the second Copy reads "else_temp" while the first one
        # writes "else_temp_grad" -- the nets are only inspected as protos
        # here, never run, so this does not affect the test; confirm intent.
        else_model = ModelHelper(name="else_test_model")
        else_model.net.Copy("else_grad", "else_temp_grad")
        else_model.net.Copy("else_temp", "input_grad")

        # NOTE(review): the operator type was absent from this call in the
        # reviewed text; "If" is inferred from the then_net/else_net arguments
        # and from the name of the helper under test -- confirm.
        grad_op = core.CreateOperator(
            "If",
            ["cond", "then_grad", "else_grad"],
            ["input_grad", "else_temp_grad"],
            then_net=then_model.net.Proto(),
            else_net=else_model.net.Proto(),
        )

        new_grad_output = "input_grad" + "_autosplit_" + "0"
        disambiguate_grad_if_op_output(grad_op, 0, new_grad_output)

        # The operator's own output list must carry the new name.
        self.assertEqual(grad_op.output[0], new_grad_output)

        # Every argument must be one of the two subnets; for the else subnet,
        # the op that used to write "input_grad" must now write the new name.
        for arg in grad_op.arg:
            if arg.name == "else_net":
                self.assertEqual(arg.n.op[1].output[0], new_grad_output)
            else:
                self.assertEqual(arg.name, "then_net")
48
if __name__ == '__main__':