# Owner(s): ["oncall: distributed"]

from copy import deepcopy

import torch
import torch.optim as optim
from torch.distributed._shard import shard_parameter, sharded_tensor
from torch.distributed._shard.sharded_optim import ShardedOptimizer
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._shard.sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)


class MyShardedModel(torch.nn.Module):
    def __init__(self, spec=None, group=None):
        super().__init__()
        # Use same seed.
        torch.manual_seed(0)
        self.param = torch.nn.Parameter(torch.rand(5, 10))
        if spec is not None:
            self.sharded_param = torch.nn.Parameter(
                sharded_tensor.rand(
                    spec, 20, 10, requires_grad=True, process_group=group
                )
            )
        else:
            self.sharded_param = torch.nn.Parameter(torch.rand(5, 10))

    def forward(self, input):
        if isinstance(self.sharded_param, sharded_tensor.ShardedTensor):
            # Only this rank's local shard participates in the local forward.
            return self.param + self.sharded_param.local_shards()[0].tensor + input
        else:
            return self.sharded_param + self.param + input


class MyShardedLinear(torch.nn.Module):
    def __init__(self, rank=None):
        super().__init__()
        # Use same seed.
        torch.manual_seed(0)
        self.linear1 = torch.nn.Linear(17, 12)
        self.linear2 = torch.nn.Linear(12, 29)
        self.gelu = torch.nn.GELU()

        if rank:
            self.linear1.cuda(rank)
            self.linear2.cuda(rank)

    def shard_parameter(self):
        # Shard linear1's weight row-wise and linear2's weight column-wise
        # across 4 GPUs.
        rowwise_sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        colwise_sharding_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
        shard_parameter(self.linear2, "weight", colwise_sharding_spec)

    def forward(self, inp):
        return self.linear2(self.gelu(self.linear1(inp)))


class TestShardedOptimizer(ShardedTensorTestBase):
    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_sharded_optim(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        local_model = MyShardedModel().cuda()
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda()

        # copy the parameters from local model
        sharded_model.sharded_param.local_shards()[0].tensor = (
            local_model.sharded_param.detach().clone().requires_grad_()
        )

        local_optim = optim.SGD(local_model.parameters(), lr=0.1)
        # ShardedOptimizer wraps a regular optimizer class over the dict of
        # named parameters, which may contain ShardedTensors.
        sharded_model_params = dict(sharded_model.named_parameters())
        sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

        local_optim.zero_grad()
        sharded_optim.zero_grad()

        before_update = deepcopy(sharded_optim.named_params)

        inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()

        # run forward
        local_output = local_model(inp)
        sharded_output = sharded_model(inp)
        # backward
        local_output.sum().backward()
        sharded_output.sum().backward()

        # optimizer update
        local_optim.step()
        sharded_optim.step()

        # make sure the parameters (including sharded param)
        # get updated by the optimizer, and the updated
        # local params are the same as the sharded params
        for key, val in before_update.items():
            new_val = sharded_optim.named_params[key]
            if isinstance(val, sharded_tensor.ShardedTensor):
                self.assertNotEqual(
                    val.local_shards()[0].tensor, new_val.local_shards()[0].tensor
                )
                self.assertEqual(
                    new_val.local_shards()[0].tensor, local_model.sharded_param
                )
            else:
                self.assertNotEqual(val, new_val)
                self.assertEqual(new_val, local_model.param)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_named_params_with_sharded_tensor(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda()
        sharded_model_params = dict(sharded_model.named_parameters())
        param_keys = list(sharded_model_params.keys())
        self.assertEqual(len(param_keys), 2)
        self.assertTrue("param" in param_keys)
        self.assertTrue("sharded_param" in param_keys)

        sharded_linear = MyShardedLinear(rank=self.rank).cuda()
        sharded_linear.shard_parameter()
        sharded_linear_params = dict(sharded_linear.named_parameters())
        param_keys = list(sharded_linear_params.keys())
        self.assertEqual(len(param_keys), 4)
        self.assertTrue("linear1.bias" in param_keys)
        self.assertTrue("linear2.bias" in param_keys)
        self.assertTrue("linear1.weight" in param_keys)
        self.assertTrue("linear2.weight" in param_keys)
        self.assertFalse("bias" in param_keys)


if __name__ == "__main__":
    run_tests()