pytorch

Форк
0
125 строк · 4.1 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import torch
4
import torch.nn as nn
5
from torch.distributed._tensor import DTensor
6
from torch.distributed.checkpoint.state_dict import get_state_dict
7
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
8
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
9
from torch.testing._internal.common_utils import run_tests
10
from torch.testing._internal.distributed._tensor.common_dtensor import (
11
    DTensorTestBase,
12
    skip_if_lt_x_gpu,
13
    with_comms,
14
)
15
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
16
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
17

18

19
class Dummymodel(nn.Module):
20
    def __init__(self):
21
        super().__init__()
22

23
    def forward(self, x):
24
        raise NotImplementedError()
25

26

27
class EPModel(nn.Module):
28
    def __init__(self, rank):
29
        super().__init__()
30
        self.net1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
31
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
32

33
    def forward(self, x):
34
        raise NotImplementedError()
35

36

37
class SecondTier(nn.Module):
38
    def __init__(self, rank):
39
        super().__init__()
40
        self.ep_layers = nn.ModuleList(
41
            [EPModel(rank) if rank % 4 == i else Dummymodel() for i in range(4)]
42
        )
43
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
44

45
    def forward(self, x):
46
        raise NotImplementedError()
47

48

49
class TopModel(nn.Module):
50
    def __init__(self, rank):
51
        super().__init__()
52
        torch.manual_seed(0)
53

54
        self.second = SecondTier(rank)
55
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
56

57
    def forward(self, x):
58
        raise NotImplementedError()
59

60

61
class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
62
    @property
63
    def world_size(self) -> int:
64
        return min(8, torch.cuda.device_count())
65

66
    @with_comms
67
    @skip_if_lt_x_gpu(8)
68
    @with_temp_dir
69
    def test_e2e(self):
70
        model = TopModel(self.rank).cuda()
71

72
        mesh_fsdp_tp = init_device_mesh(
73
            self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
74
        )
75
        # TODO: we are using an internal API atm. Change to a publich API once it is ready.
76
        mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, 0, "dp")
77
        del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep]
78

79
        mesh_fsdp = init_device_mesh(self.device_type, (8,))
80
        for i, l in enumerate(model.second.ep_layers):
81
            model.second.ep_layers[i] = FSDP(
82
                l, use_orig_params=True, device_mesh=mesh_fsdp_ep
83
            )
84
        model.second = FSDP(model.second, use_orig_params=True, device_mesh=mesh_fsdp)
85
        model = FSDP(model, use_orig_params=True, device_mesh=mesh_fsdp)
86
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
87
        msd, osd = get_state_dict(model, optim)
88

89
        # FSDP only params
90
        for key in (
91
            "net.0.weight",
92
            "net.0.bias",
93
            "second.net.0.weight",
94
            "second.net.0.bias",
95
        ):
96
            msd_v = msd[key]
97
            osd_v = osd["state"][key]["exp_avg"]
98
            for v in (msd_v, osd_v):
99
                self.assertTrue(isinstance(v, DTensor))
100
                self.assertEqual(tuple(v.device_mesh.mesh), tuple(range(8)))
101

102
        # FSDP/EP params
103
        layer = self.rank % 4
104
        ranks = (layer, layer + 4)
105
        for i in range(4):
106
            for key in (
107
                f"second.ep_layers.{i}.net1.0.weight",
108
                f"second.ep_layers.{i}.net1.0.bias",
109
                f"second.ep_layers.{i}.net2.0.weight",
110
                f"second.ep_layers.{i}.net2.0.bias",
111
            ):
112
                if layer != i:
113
                    self.assertTrue(key not in msd)
114
                else:
115
                    msd_v = msd[key]
116
                    osd_v = osd["state"][key]["exp_avg"]
117
                    for v in (msd_v, osd_v):
118
                        self.assertTrue(isinstance(v, DTensor))
119
                        self.assertEqual(tuple(v.device_mesh.mesh), ranks)
120

121
        self.assertEqual(set(osd["state"].keys()), set(msd.keys()))
122

123

124
if __name__ == "__main__":
125
    run_tests()
126

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.