pytorch / test_fsdp_freezing_weights.py (250 lines, 7.2 KB)
# Owner(s): ["oncall: distributed"]

import contextlib
import sys
from enum import Enum

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_full_params
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


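# Flat variant: a convolutional trunk followed by a linear head. When FSDP is
# enabled, `fsdp_wrap` wraps the trunk and the head as separate FSDP units;
# with `disable_autograd`, the trunk runs under torch.no_grad in forward.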
class Model(nn.Module):
    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        self.device = torch.cuda.current_device()
        self.head = nn.Linear(64, 10)
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)


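# Nested variant: each conv block inside the trunk is wrapped in FSDP, then the
# trunk and the head are wrapped again, exercising nested FSDP instances.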
class NestedTrunkModel(nn.Module):
    def __init__(
        self,
        with_fsdp,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        super().__init__()
        self.trunk = nn.Sequential(
            self._create_block(3, 64, with_fsdp, freeze_after_wrap_fsdp),
            self._create_block(64, 64, with_fsdp, freeze_after_wrap_fsdp),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        )
        if with_fsdp and freeze_after_wrap_fsdp:
            self.fsdp_wrap(fsdp_kwargs)
        self.autograd_ctx = (
            torch.no_grad if disable_autograd else contextlib.nullcontext
        )

    def fsdp_wrap(self, fsdp_kwargs):
        for name, child in self.trunk.named_children():
            wrapped_child = FSDP(child, **fsdp_kwargs)
            setattr(self.trunk, name, wrapped_child)
        self.trunk = FSDP(self.trunk, **fsdp_kwargs)
        self.head = FSDP(self.head, **fsdp_kwargs)

    def forward(self, x):
        with self.autograd_ctx():
            x = self.trunk(x)
        return self.head(x)

    def _create_block(
        self, in_channels, out_channels, with_fsdp, freeze_after_wrap_fsdp
    ):
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        return block


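# The two freezing strategies exercised below: set requires_grad=False on the
# trunk parameters, or leave autograd on and reset each trunk .grad to None
# after every backward pass so the optimizer never updates the trunk.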
class FreezingMethod(str, Enum):
    GradToNone = "grad_to_none"
    RequiresGrad = "requires_grad"


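# Runs the same three-step training loop once under DistributedDataParallel and
# once under FSDP, then compares the resulting (unsharded) parameters; the two
# should match regardless of how the trunk was frozen.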
class TestFreezingWeights(FSDPTest):
    def _create_model(
        self,
        with_fsdp,
        with_nested_trunk,
        freeze_after_wrap_fsdp,
        disable_autograd,
        fsdp_kwargs,
    ):
        if with_nested_trunk:
            model = NestedTrunkModel(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        else:
            model = Model(
                with_fsdp, freeze_after_wrap_fsdp, disable_autograd, fsdp_kwargs
            )
        return model

    def _dist_train(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        with_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        fsdp_kwargs = {
            "device_id": self.rank,
            "forward_prefetch": forward_prefetch,
        }

        ddp_kwargs = {
            "device_ids": [self.rank],
            "find_unused_parameters": True if disable_autograd else False,
        }

        model = self._create_model(
            with_fsdp,
            with_nested_trunk,
            freeze_after_wrap_fsdp,
            disable_autograd,
            fsdp_kwargs,
        )
        model = model.cuda()

        # freezing the trunk using requires_grad.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                model.fsdp_wrap(fsdp_kwargs)
            model = FSDP(model, **fsdp_kwargs)
        else:
            model = DistributedDataParallel(model, **ddp_kwargs)

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        for iteration in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                for param in model.module.trunk.parameters():
                    param.grad = None
            optimizer.step()

        if with_fsdp:
            return get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("with_nested_trunk", [True, False])
    @parametrize(
        "freezing_method", [FreezingMethod.RequiresGrad, FreezingMethod.GradToNone]
    )
    @parametrize("freeze_after_wrap_fsdp", [True, False])
    @parametrize("disable_autograd", [True, False])
    @parametrize("forward_prefetch", [True, False])
    def test_freezing_weights(
        self,
        with_nested_trunk,
        freezing_method,
        freeze_after_wrap_fsdp,
        disable_autograd,
        forward_prefetch,
    ):
        # DDP
        ddp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=False,
            disable_autograd=disable_autograd,
            forward_prefetch=False,  # does not apply to DDP
        )

        # FSDP
        fsdp_state = self._dist_train(
            with_nested_trunk,
            freezing_method,
            freeze_after_wrap_fsdp,
            with_fsdp=True,
            disable_autograd=disable_autograd,
            forward_prefetch=forward_prefetch,
        )

        self.assertEqual(
            ddp_state,
            fsdp_state,
            exact_device=True,
            msg="FullyShardedDataParallel states didn't match PyTorch DDP states",
        )

        if freezing_method == FreezingMethod.RequiresGrad:
            for ddp_param, fsdp_param in zip(ddp_state, fsdp_state):
                self.assertEqual(ddp_param.requires_grad, fsdp_param.requires_grad)


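# To run this file directly, at least two CUDA GPUs are required (see
# @skip_if_lt_x_gpu(2)); the command below assumes the usual location of this
# test in the PyTorch repo:
#
#   python test/distributed/fsdp/test_fsdp_freezing_weights.py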
instantiate_parametrized_tests(TestFreezingWeights)

if __name__ == "__main__":
    run_tests()