# colossalai
# 49 lines · 1.4 KB
import torch

from colossalai.fx import ColoTracer
from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top
from colossalai.testing import clear_cache_before_run


class MLP(torch.nn.Module):
    """Toy network of five square ``Linear`` layers arranged as two independent chains.

    Branch A: ``x -> linear1 -> linear3 -> linear5``
    Branch B: ``x -> linear2 -> linear4``

    ``forward`` returns ``(branch B output, branch A output)``.

    Args:
        dim: Input and output feature size of every linear layer.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Create linear1 ... linear5 in ascending order, so registration order
        # (and RNG consumption during initialisation) matches explicit assignments.
        for idx in range(1, 6):
            setattr(self, f"linear{idx}", torch.nn.Linear(dim, dim))

    def forward(self, x):
        # Keep this exact call order: test_graph_manipulation unpacks the
        # traced graph's nodes positionally.
        a_first = self.linear1(x)
        b_first = self.linear2(x)
        a_second = self.linear3(a_first)
        b_second = self.linear4(b_first)
        a_third = self.linear5(a_second)
        return b_second, a_third


@clear_cache_before_run()
def test_graph_manipulation():
    """Check leaf/top-node detection and BFS level assignment on a traced ``MLP``.

    The traced graph is expected to hold, in order: the input placeholder,
    the five linear calls (linear1..linear5) and the output node.
    ``linear4`` and ``linear5`` should be the leaves, ``linear1`` and
    ``linear2`` the top nodes. After ``assign_bfs_level_to_nodes``, each
    compute node should carry its depth in ``bfs_level``, and the
    placeholder/output nodes should carry none.
    """
    model = MLP(4)
    tracer = ColoTracer()
    graph = tracer.trace(model)
    nodes = list(graph.nodes)
    # Fail with a readable message instead of an opaque unpacking ValueError
    # if the tracer emits a different number of nodes.
    assert len(nodes) == 7, f"expected 7 traced nodes, got {len(nodes)}: {nodes}"
    x, l1, l2, l3, l4, l5, output = nodes
    assert x.op == "placeholder", f"first node should be the input placeholder, got op={x.op!r}"
    assert output.op == "output", f"last node should be the output node, got op={output.op!r}"

    leaf_nodes = set(get_leaf(graph))
    top_nodes = set(get_top(graph))
    # Only compute nodes have an expected level; placeholder/output are checked separately below.
    expected_levels = {l1: 0, l2: 0, l3: 1, l4: 1, l5: 2}
    assign_bfs_level_to_nodes(graph)

    assert leaf_nodes == {l4, l5}, f"unexpected leaf nodes: {leaf_nodes}"
    assert top_nodes == {l1, l2}, f"unexpected top nodes: {top_nodes}"
    for node in graph.nodes:
        if node.op in ("placeholder", "output"):
            assert not hasattr(node, "bfs_level"), f"{node} should not have a bfs_level"
        else:
            assert node.bfs_level == expected_levels[node], (
                f"{node}: bfs_level {node.bfs_level} != expected {expected_levels[node]}"
            )


if __name__ == "__main__":
    # Allow running this module directly, without going through pytest.
    test_graph_manipulation()