colossalai

Форк
0
/
test_comm_size_compute.py 
52 строки · 1.7 Кб
1
import torch
2
from torch.fx import symbolic_trace
3

4
from colossalai.fx._compatibility import is_compatible_with_meta
5
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
6
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
7
from colossalai.fx.passes.utils import get_comm_size
8
from colossalai.testing import clear_cache_before_run
9

10
is_compatible = is_compatible_with_meta()
11
if is_compatible:
12
    from colossalai.fx.profiler import MetaTensor
13

14
MODEL_DIM = 16
15
BATCH_SIZE = 8
16
PIPELINE_SIZE = 2
17

18

19
class MLP(torch.nn.Module):
20
    def __init__(self, dim: int):
21
        super().__init__()
22
        self.linear1 = torch.nn.Linear(dim, dim)
23
        self.linear2 = torch.nn.Linear(dim, dim)
24
        self.linear3 = torch.nn.Linear(dim, dim)
25
        self.linear4 = torch.nn.Linear(dim, dim)
26

27
    def forward(self, x):
28
        x = self.linear1(x)
29
        x = self.linear2(x)
30
        x = self.linear3(x)
31
        x = self.linear4(x)
32
        return x
33

34

35
@clear_cache_before_run()
36
def test_comm_size_compute():
37
    model = MLP(MODEL_DIM)
38
    input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device="meta")
39
    gm = symbolic_trace(model)
40
    if is_compatible:
41
        input_sample = MetaTensor(input_sample, fake_device=next(gm.parameters()).device)
42
    MetaInfoProp(gm).run(input_sample)
43
    annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
44
    split_model, split_submodules = split_with_split_nodes_pass(annotated_model)
45
    submodule_list = list(split_model.children())
46
    comm_size = get_comm_size(submodule_list[0], submodule_list[1])
47
    # the shape of tensor send from partition 0 to partition 1 is (8, 16)
48
    assert comm_size == 128
49

50

51
if __name__ == "__main__":
52
    test_comm_size_compute()
53

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.