colossalai
52 строки · 1.5 Кб
1import torch2from torch.fx import symbolic_trace3
4from colossalai.fx.passes.adding_split_node_pass import (5balanced_split_pass,6balanced_split_pass_v2,7split_with_split_nodes_pass,8uniform_split_pass,9)
10from colossalai.testing import clear_cache_before_run11
# Sizes shared by the test below: MODEL_DIM is the MLP width and the feature
# size of the random input, BATCH_SIZE is the number of input rows, and
# PIPELINE_SIZE is the second argument handed to each split pass
# (presumably the number of pipeline stages — confirm against the passes).
MODEL_DIM, BATCH_SIZE, PIPELINE_SIZE = 16, 8, 2
16
class MLP(torch.nn.Module):
    """Stack of four ``dim -> dim`` linear layers applied one after another.

    There is no activation between layers. This is the model that the test
    below traces with ``torch.fx.symbolic_trace`` and splits with the
    pipeline passes.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Keep the layers as separately named attributes (linear1..linear4),
        # created and registered in this order.
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)
        self.linear4 = torch.nn.Linear(dim, dim)

    def forward(self, x):
        """Feed ``x`` through linear1, linear2, linear3 and linear4 in turn."""
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            x = layer(x)
        return x
32
def pipeline_pass_test_helper(model, data, pass_func, pipeline_size=None):
    """Check that splitting ``model`` with ``pass_func`` preserves its output.

    The steps are:

    1. Run the model eagerly on ``data`` to get a reference output.
    2. Trace it with ``torch.fx.symbolic_trace``.
    3. Annotate the traced module with ``pass_func``.
    4. Split it with ``split_with_split_nodes_pass``.

    The resulting split module must reproduce the reference output exactly.

    Args:
        model: the ``torch.nn.Module`` under test.
        data: input fed unchanged to both the original and the split model.
        pass_func: annotation pass, called as
            ``pass_func(traced_module, pipeline_size)``.
        pipeline_size: second argument forwarded to ``pass_func``. When
            omitted, the module-level ``PIPELINE_SIZE`` is used.

    Raises:
        AssertionError: if the split model's output is not exactly equal
            (same shape and values) to the original model's output.
    """
    if pipeline_size is None:
        # Resolved here rather than as a default value so the module constant
        # is read when the helper runs, exactly as the original code did.
        pipeline_size = PIPELINE_SIZE
    origin_output = model(data)
    symbolic_traced = symbolic_trace(model)
    annotated_model = pass_func(symbolic_traced, pipeline_size)
    # The second value returned by the split pass is not needed for this check.
    split_model, _ = split_with_split_nodes_pass(annotated_model)
    output = split_model(data)
    pass_name = getattr(pass_func, "__name__", repr(pass_func))
    assert torch.equal(output, origin_output), (
        f"{pass_name}: split model output differs from the original model output"
    )
41
@clear_cache_before_run()
def test_pipeline_passes():
    """Split one MLP with each pipeline pass and compare against eager output."""
    # The model is built before the input tensor so the random-number draws
    # happen in the same order as before.
    net = MLP(MODEL_DIM)
    inputs = torch.rand(BATCH_SIZE, MODEL_DIM)
    # The same model and input are reused for every pass, checked in order.
    for split_pass in (balanced_split_pass, balanced_split_pass_v2, uniform_split_pass):
        pipeline_pass_test_helper(net, inputs, split_pass)
50
if __name__ == "__main__":
    # Allow running the check directly as a script, outside a test runner.
    test_pipeline_passes()