pytorch

Форк
0
124 строки · 3.4 Кб
1
# Copyright (c) Meta Platforms, Inc. and affiliates
2
# Owner(s): ["oncall: distributed"]
3
from model_registry import MLPModule, ModelWithParamAlias
4

5
import torch
6
from torch.distributed.pipelining import pipe_split, pipeline
7
from torch.testing._internal.common_utils import (
8
    instantiate_parametrized_tests,
9
    parametrize,
10
    run_tests,
11
    TestCase,
12
)
13

14

15
d_hid = 512
16
microbatch_size = 16
17

18
torch.manual_seed(0)
19

20

21
# Basic example
22
class ExampleCode(torch.nn.Module):
23
    def __init__(self) -> None:
24
        super().__init__()
25
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
26
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
27
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
28
        self.lin2 = torch.nn.Linear(d_hid, d_hid)
29

30
    def forward(self, x, y):
31
        x = torch.mm(x, self.mm_param1)  # mutli-use param
32
        skip_connection = x
33
        x = x + y
34
        x = torch.relu(x)
35
        pipe_split()
36
        x = torch.mm(x, self.mm_param1)  # mutli-use param
37
        x = self.lin1(x)
38
        pipe_split()
39
        x = torch.relu(x)
40
        x = x + skip_connection
41
        x = torch.mm(x, self.mm_param2)
42
        pipe_split()
43
        x = self.lin2(x)
44
        x = torch.relu(x)
45
        return x
46

47

48
class MultiMLP(torch.nn.Module):
49
    def __init__(self) -> None:
50
        super().__init__()
51
        self.mlp0 = MLPModule(d_hid)
52
        self.mlp1 = MLPModule(d_hid)
53
        self.mlp2 = MLPModule(d_hid)
54
        self.mlp3 = MLPModule(d_hid)
55

56
    def forward(self, x, y):
57
        x = self.mlp0(x)
58
        pipe_split()
59
        x = self.mlp1(x)
60
        pipe_split()
61
        x = self.mlp2(x)
62
        pipe_split()
63
        x = self.mlp3(x)
64
        return x - y
65

66

67
EXPECTED_N_STAGES = {
68
    ExampleCode: 4,
69
    MultiMLP: 4,
70
    ModelWithParamAlias: 2,
71
}
72

73
# Currently, we don't enforce full set equality on the FQNs between the original
74
# and pipelined models, because in the multi-use param case, PP will deduplicate
75
# the FQNs from the state_dict.
76
# TODO
77
CHECK_FQN_SET_EQUALITY = False
78

79

80
class PipeTests(TestCase):
81
    @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
82
    def test_model_split(self, ModelClass):
83
        mod = ModelClass()
84
        x = torch.randn(microbatch_size, d_hid)
85
        y = torch.randn(microbatch_size, d_hid)
86

87
        pipe = pipeline(
88
            mod,
89
            mb_args=(x, y),
90
        )
91

92
        assert (
93
            pipe.num_stages == EXPECTED_N_STAGES[ModelClass]
94
        ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}"
95

96
        ref_out = mod(x, y)
97
        out = pipe(x, y)[0]
98
        torch.testing.assert_close(out, ref_out)
99
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")
100

101
        # Check qualname
102
        # state_dict.keys include both parameters and persistent buffers
103
        old_names = set(mod.state_dict().keys())
104
        new_names = set()
105
        for idx in range(pipe.num_stages):
106
            stage_mod = pipe.get_stage_module(idx)
107
            stage_fqns = set(stage_mod.state_dict().keys())
108
            assert stage_fqns.issubset(old_names)
109
            new_names.update(stage_fqns)
110

111
        if CHECK_FQN_SET_EQUALITY:
112
            assert (
113
                old_names == new_names
114
            ), f"""
115
            old names {old_names}
116
            new names {new_names}
117
            """
118
        print("Qualname check passed")
119

120

121
instantiate_parametrized_tests(PipeTests)
122

123
if __name__ == "__main__":
124
    run_tests()
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.