colossalai

Форк
0
/
test_graph_manipulation.py 
49 строк · 1.4 Кб
1
import torch
2

3
from colossalai.fx import ColoTracer
4
from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top
5
from colossalai.testing import clear_cache_before_run
6

7

8
class MLP(torch.nn.Module):
9
    def __init__(self, dim: int):
10
        super().__init__()
11
        self.linear1 = torch.nn.Linear(dim, dim)
12
        self.linear2 = torch.nn.Linear(dim, dim)
13
        self.linear3 = torch.nn.Linear(dim, dim)
14
        self.linear4 = torch.nn.Linear(dim, dim)
15
        self.linear5 = torch.nn.Linear(dim, dim)
16

17
    def forward(self, x):
18
        l1 = self.linear1(x)
19
        l2 = self.linear2(x)
20
        l3 = self.linear3(l1)
21
        l4 = self.linear4(l2)
22
        l5 = self.linear5(l3)
23
        return l4, l5
24

25

26
@clear_cache_before_run()
27
def test_graph_manipulation():
28
    model = MLP(4)
29
    tracer = ColoTracer()
30
    graph = tracer.trace(model)
31
    nodes = list(graph.nodes)
32
    x, l1, l2, l3, l4, l5, output = nodes
33

34
    leaf_nodes = set(get_leaf(graph))
35
    top_nodes = set(get_top(graph))
36
    compare_dict = {x: None, l1: 0, l2: 0, l3: 1, l4: 1, l5: 2, output: None}
37
    assign_bfs_level_to_nodes(graph)
38

39
    assert leaf_nodes == set([l4, l5])
40
    assert top_nodes == set([l1, l2])
41
    for node in graph.nodes:
42
        if node.op in ("placeholder", "output"):
43
            assert not hasattr(node, "bfs_level")
44
        else:
45
            assert node.bfs_level == compare_dict[node]
46

47

48
if __name__ == "__main__":
49
    test_graph_manipulation()
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.