# Owner(s): ["oncall: distributed"]
import copy
import sys

import torch
import torch.nn as nn
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharder import Sharder
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    TEST_GPU_NUM,
    with_comms,
)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


# a simple collection of embedding bag implementations
class CustomEmbeddingBagCollection(nn.Module):
    def __init__(self, num_bags, num_embeddings_per_bag, num_dims):
        super().__init__()
        self.num_bags = num_bags
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        for i in range(num_bags):
            self.embedding_bags[f"embedding_bag_{i}"] = nn.EmbeddingBag(
                num_embeddings_per_bag, num_dims, mode="sum"
            )

    def forward(self, inputs):
        # run every bag over the same inputs and concatenate the pooled outputs
        outputs = []
        for bag in self.embedding_bags.values():
            outputs.append(bag(inputs))
        return torch.cat(outputs)


# a simple sharded version of the embedding bag collection (EBC)
class CustomShardedEBC(nn.Module):
    def __init__(self, ebc, split_idx, specs):
        super().__init__()
        self.split_idx = split_idx
        row_spec, col_spec = specs

        # create embedding bags based on the spec
        self.embedding_bags: nn.ModuleDict = nn.ModuleDict()

        assert self.split_idx < ebc.num_bags
        for i in range(ebc.num_bags):
            bag_key = f"embedding_bag_{i}"
            # bags before the split index are sharded row-wise, the rest column-wise
            if i < self.split_idx:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": row_spec}
                    ),
                )
            else:
                shard_module(
                    ebc,
                    plan=ShardingPlan(
                        plan={f"embedding_bags.{bag_key}.weight": col_spec}
                    ),
                )

            self.embedding_bags[bag_key] = ebc.embedding_bags[bag_key]


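# a custom Sharder that takes a whole CustomEmbeddingBagCollection submodule:
# bags before `split_sharding_idx` get the row-wise spec, the rest the column-wise spec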
class CustomSharder(Sharder):
    def __init__(self, devices, split_sharding_idx):
        self.devices = devices
        self.split_sharding_idx = split_sharding_idx
        self.rowwise_spec = ChunkShardingSpec(dim=0, placements=devices)
        self.colwise_spec = ChunkShardingSpec(dim=1, placements=devices)

    def shard(self, ebc: nn.Module) -> nn.Module:
        if not isinstance(ebc, CustomEmbeddingBagCollection):
            raise RuntimeError(
                "The custom sharder only supports CustomEmbeddingBagCollection"
            )

        return CustomShardedEBC(
            ebc, self.split_sharding_idx, (self.rowwise_spec, self.colwise_spec)
        )


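# tests that exercise the custom sharder end to end via shard_module() and ShardingPlan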
class TestCustomSharder(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ebc = CustomEmbeddingBagCollection(10, 10, 8)

            def forward(self, inputs):
                return self.ebc(inputs)

        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        sharding_plan = ShardingPlan(
            plan={
                "ebc": custom_sharder,
            }
        )

        local_model = MyModule().cuda(self.rank)
        sharded_model = copy.deepcopy(local_model)

        # shard the module with the provided sharding plan
        shard_module(sharded_model, sharding_plan)

        # check to make sure the module has already been sharded
        emb_bags = sharded_model.ebc.embedding_bags
        self.assertTrue(isinstance(emb_bags["embedding_bag_0"].weight, ShardedTensor))
        self.assertTrue(isinstance(emb_bags["embedding_bag_9"].weight, ShardedTensor))
        self.assertEqual(
            emb_bags["embedding_bag_0"].weight.sharding_spec(),
            custom_sharder.rowwise_spec,
        )
        self.assertEqual(
            emb_bags["embedding_bag_9"].weight.sharding_spec(),
            custom_sharder.colwise_spec,
        )

        # make sure we can run sharded computation and compare outputs
        # with the local model version
        input = torch.arange(8).reshape((2, 4)).cuda(self.rank)
        local_output = local_model(input)
        sharded_output = sharded_model(input)

        self.assertEqual(local_output, sharded_output)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_custom_sharder_errors(self):
        custom_sharder = CustomSharder(
            devices=[f"rank:{i}/cuda:{i}" for i in range(TEST_GPU_NUM)],
            split_sharding_idx=TEST_GPU_NUM // 2,
        )

        # a custom sharder must be attached to a non-empty module path
        sharding_plan = ShardingPlan(
            plan={
                "": custom_sharder,
            }
        )

        sharded_model = CustomEmbeddingBagCollection(10, 10, 8).cuda(self.rank)

        with self.assertRaisesRegex(
            KeyError, "path must not be empty for custom sharder!"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)

        # test a conflicting sharding plan: a parameter spec must not target a
        # subtree that is already covered by a custom sharder
        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
        sharding_plan = ShardingPlan(
            plan={
                "embedding_bags.embedding_bag_0.weight": spec,
                "embedding_bags": custom_sharder,
            }
        )

        with self.assertRaisesRegex(
            RuntimeError, "should not conflict with the submodule tree"
        ):
            # shard the module with the provided sharding plan
            shard_module(sharded_model, sharding_plan)


if __name__ == "__main__":
    run_tests()