pytorch

Форк
0
/
control_ops_grad_test.py 
44 строки · 1.7 Кб
1

2

3

4

5

6
import unittest
7
from caffe2.python import core, test_util, workspace
8
from caffe2.python.control_ops_grad import disambiguate_grad_if_op_output
9
from caffe2.python.model_helper import ModelHelper
10
import numpy as np
11

12

13
class TestControl(test_util.TestCase):
14
    def test_disambiguate_grad_if_op_output(self):
15
        workspace.FeedBlob("cond", np.array(True))
16
        workspace.FeedBlob("then_grad", np.array(1))
17
        workspace.FeedBlob("else_grad", np.array(2))
18

19
        then_model = ModelHelper(name="then_test_model")
20
        then_model.net.Copy("then_grad", "input_grad")
21

22
        else_model = ModelHelper(name="else_test_model")
23
        else_model.net.Copy("else_grad", "else_temp_grad")
24
        else_model.net.Copy("else_temp", "input_grad")
25

26
        # to BuildGradientGenerators, in forward pass, we need else temp
27
        # as one of the output. Which later on results in a grad op like this:
28
        grad_op = core.CreateOperator(
29
            "If",
30
            ["cond", "then_grad", "else_grad"],
31
            ["input_grad", "else_temp_grad"],
32
            then_net=then_model.net.Proto(),
33
            else_net=else_model.net.Proto(),
34
        )
35

36
        # in certain cases, another branch of the net also generates input_grad
37
        # and we call _DisambiguateGradOpOutput in core.py
38
        new_grad_output = "input_grad" + "_autosplit_" + "0"
39
        disambiguate_grad_if_op_output(grad_op, 0, new_grad_output)
40
        self.assertEqual(grad_op.output[0], new_grad_output)
41
        for arg in grad_op.arg:
42
            if arg.name == "else_net":
43
                self.assertEqual(arg.n.op[1].output[0], new_grad_output)
44
            else:
45
                self.assertEqual(arg.name, "then_net")
46

47

48
if __name__ == '__main__':
49
    unittest.main()
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.