pytorch

Форк
0
/
test_fx_xform_observer.py 
60 строк · 1.7 Кб
1
# Owner(s): ["module: fx"]
2

3
import os
4
import tempfile
5

6
import torch
7
from torch.fx import subgraph_rewriter, symbolic_trace
8
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
9
from torch.testing._internal.common_utils import TestCase
10

11

12
if __name__ == "__main__":
13
    raise RuntimeError(
14
        "This test file is not meant to be run directly, use:\n\n"
15
        "\tpython test/test_fx.py TESTNAME\n\n"
16
        "instead."
17
    )
18

19

20
class TestGraphTransformObserver(TestCase):
21
    def test_graph_transform_observer(self):
22
        class M(torch.nn.Module):
23
            def forward(self, x):
24
                val = torch.neg(x)
25
                return torch.add(val, val)
26

27
        def pattern(x):
28
            return torch.neg(x)
29

30
        def replacement(x):
31
            return torch.relu(x)
32

33
        traced = symbolic_trace(M())
34

35
        log_url = tempfile.mkdtemp()
36

37
        with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob:
38
            subgraph_rewriter.replace_pattern(traced, pattern, replacement)
39

40
            self.assertTrue("relu" in ob.created_nodes)
41
            self.assertTrue("neg" in ob.erased_nodes)
42

43
        current_pass_count = GraphTransformObserver.get_current_pass_count()
44

45
        self.assertTrue(
46
            os.path.isfile(
47
                os.path.join(
48
                    log_url,
49
                    f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.dot",
50
                )
51
            )
52
        )
53
        self.assertTrue(
54
            os.path.isfile(
55
                os.path.join(
56
                    log_url,
57
                    f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.dot",
58
                )
59
            )
60
        )
61

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.