# Owner(s): ["oncall: distributed"]
import copy
import sys

import torch
import torch.nn as nn
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    TEST_GPU_NUM,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# a simple collection of embedding bag implementations
class CustomEmbeddingBagCollection(nn.Module):
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag, num_dims, mode="sum"
            )

    def forward(self, inputs):
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)


# a simple sharded version of EBC
class CustomShardedEBC(nn.Module):
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the spec
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
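        # bags with an index below split_idx are sharded row-wise, the rest column-wise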
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            if i < self.split_idx:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": row_spec}
                    ),
                )
            else:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": col_spec}
                    ),
                )

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]


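# a custom Sharder: shard_module() is expected to call shard() on the submodule
# mapped to this sharder in the ShardingPlan and swap in the module it returns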
class CustomSharder(Sharder):
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
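        # chunk along dim 0 to shard rows across the devices, dim 1 to shard columns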
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError(
                "The custom sharder only supports CustomEmbeddingBagCollection"
            )

        return CustomShardedEBC(
            ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec)
        )


class TestCustomSharder(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

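        # map the submodule path "ebc" to the custom sharder so that shard_module()
        # delegates sharding of that submodule to CustomSharder.shard()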
        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            }
        )

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has already been sharded
        emb_bags = sharded_model.ebc.embedding_bags
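        # bag 0 falls below split_sharding_idx (row-wise); bag 9 at or above it (column-wise)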
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(
            emb_bags["embedding_bag_0"].weight.sharding_spec(),
            custom_sharder.rowwise_spec,
        )
        self.assertEqual(
            emb_bags["embedding_bag_9"].weight.sharding_spec(),
            custom_sharder.colwise_spec,
        )

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

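        # a custom sharder must be attached to a non-root submodule path,
        # so the empty path below should be rejected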
        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            }
        )

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan
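        # a spec for a parameter inside the subtree owned by the custom sharder
        # should be rejected as a conflict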
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            }
        )

        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)


if __name__ == "__main__":
    run_tests()
