pytorch

Форк
0
/
test_fsdp_meta.py 
424 строки · 14.3 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import itertools
4
import sys
5

6
from typing import Union
7

8
import torch
9
import torch.distributed as dist
10
import torch.nn as nn
11
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
12
from torch.distributed.fsdp.wrap import (
13
    always_wrap_policy as always_wrap,
14
    enable_wrap,
15
    ModuleWrapPolicy,
16
    wrap,
17
)
18
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
19
from torch.testing._internal.common_fsdp import FSDPTest
20
from torch.testing._internal.common_utils import (
21
    instantiate_parametrized_tests,
22
    parametrize,
23
    run_tests,
24
    skip_but_pass_in_sandcastle_if,
25
    TEST_WITH_DEV_DBG_ASAN,
26
)
27

28
# torchdistX is optional: record whether it could be imported so the tests
# that need it can be skipped instead of failing at import time.
try:
    from torchdistx import deferred_init
except ImportError:
    _TORCHDISTX_AVAIL = False
else:
    _TORCHDISTX_AVAIL = True


# Nothing in this file can run without the distributed package; exit with a
# success status so that the run is not reported as a failure.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Same early exit for the dev-dbg-asan configuration.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
45

46

47
def _reset_params_if_meta(is_meta: bool, model: nn.Module):
    """
    Call ``reset_parameters()`` on every submodule of ``model`` that defines
    it, but only when ``is_meta`` is true.

    Args:
        is_meta: Whether the model under test was created on the meta device.
        model: Reference model whose submodules should be re-initialized.
    """
    # For torchdistX init, we don't need to call reset_params, as
    # deferred_init(model).materialize() is equivalent to model().
    if not is_meta:
        return
    for submodule in model.modules():
        # Assume that a module has `reset_parameters()` iff it has directly
        # managed parameters or buffers
        if not hasattr(submodule, "reset_parameters"):
            continue
        submodule.reset_parameters()
56

57

58
class MyLinear(nn.Linear):
    """
    Linear layer with deterministic reset_parameters for testing.

    ``nn.Linear.__init__`` calls ``reset_parameters()``, so every instance is
    initialized through the override below. The constructor arguments are
    those of ``nn.Linear``.
    """

    def reset_parameters(self, *args, **kwargs):
        """
        Re-seed the global RNG with 42 and re-initialize the layer in place.

        The weight is drawn with Xavier-uniform initialization and the bias,
        when present, is zero-filled. Extra positional and keyword arguments
        are accepted and ignored.
        """
        torch.manual_seed(42)
        with torch.no_grad():
            # Use an initialization method that depends on shape
            torch.nn.init.xavier_uniform_(self.weight, gain=1.0)
            # Without this the bias would keep the uninitialized memory it
            # was allocated with, making ``bias=True`` layers
            # non-deterministic.
            if self.bias is not None:
                torch.nn.init.zeros_(self.bias)
71

72

73
class MyBuffer(nn.Module):
    """Parameter-free module that owns a single 3x3 buffer named ``buf``."""

    def __init__(self, device: torch.device):
        super().__init__()
        storage = torch.empty(3, 3, device=device)
        self.register_buffer("buf", storage)

    def reset_parameters(self, *args, **kwargs):
        """
        Re-seed the global RNG with 42 and refill ``buf`` in place.

        Extra positional and keyword arguments are accepted and ignored.
        """
        torch.manual_seed(42)
        # Use an initialization method that depends on shape
        nn.init.xavier_uniform_(self.buf, gain=0.5)
82

83

84
class MyModel(nn.Module):
    """
    Two bias-free 2x2 ``MyLinear`` layers applied in sequence, plus a
    ``MyBuffer`` submodule that does not take part in ``forward``.
    """

    def __init__(self, device: torch.device):
        super().__init__()
        linear_kwargs = {"bias": False, "device": device}
        self.lin1 = MyLinear(2, 2, **linear_kwargs)
        self.lin2 = MyLinear(2, 2, **linear_kwargs)
        self.buf_mod = MyBuffer(device)

    def forward(self, x):
        hidden = self.lin1(x)
        return self.lin2(hidden)
93

94

95
class NestedModel(nn.Module):
    """
    Model whose ``lin1`` and ``l3`` submodules are passed through ``wrap()``
    at construction time, while ``lin2`` is left as a plain ``MyLinear``.
    """

    def __init__(self, device):
        super().__init__()
        self.lin1 = wrap(MyLinear(2, 2, bias=False, device=device))
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.l3 = wrap(MyModel(device=device))

    def forward(self, x):
        out = self.lin1(x)
        out = self.lin2(out)
        return self.l3(out)
106

107

108
def _init_with_reset_params(module: nn.Module):
    """
    to_empty + reset_parameters() init function example for modules
    initialized with device="meta"

    Only the parameters and buffers owned directly by ``module`` are
    inspected (``recurse=False``); the call is a no-op when none of them is
    on the meta device.
    """
    own_tensors = itertools.chain(
        module.parameters(recurse=False), module.buffers(recurse=False)
    )
    if not any(tensor.is_meta for tensor in own_tensors):
        return
    # Allocate uninitialized storage on the current CUDA device, then let the
    # module fill it in.
    target = torch.device("cuda", torch.cuda.current_device())
    module.to_empty(device=target, recurse=False)
    module.reset_parameters()
123

124

125
def _init_with_torchdistX(module: nn.Module):
    """
    torchdistX-based deferred module initialization function example
    using ``materialize_module``.
    """
    assert _TORCHDISTX_AVAIL

    def _is_not_fsdp(candidate):
        # True for everything except FSDP instances.
        # NOTE(review): presumably torchdistX only materializes submodules
        # for which ``check_fn`` returns True -- confirm against its docs.
        return not isinstance(candidate, FSDP)

    deferred_init.materialize_module(module, check_fn=_is_not_fsdp)
136

137

138
class TestFSDPWithMetaDevice(FSDPTest):
    """
    Tests FSDP construction from modules created on the meta device or via
    torchdistX deferred initialization.

    Most tests build such a model, wrap it with FSDP, and check that its
    parameters match those of a reference model built directly on CUDA, both
    before and after one SGD step.
    """

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    def _compare_fsdp(self, fsdp1, fsdp2):
        """
        Assert that two FSDP models hold pairwise-close full parameters.

        Args:
            fsdp1: First FSDP-wrapped model.
            fsdp2: Second FSDP-wrapped model.
        """
        with FSDP.summon_full_params(fsdp1):
            with FSDP.summon_full_params(fsdp2):
                params1 = list(fsdp1.parameters())
                params2 = list(fsdp2.parameters())
                # ``zip`` stops at the shorter input, so a model with missing
                # parameters would otherwise compare as equal.
                self.assertEqual(
                    len(params1),
                    len(params2),
                    f"Parameter count mismatch: {len(params1)} vs {len(params2)}",
                )
                for p1, p2 in zip(params1, params2):
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    def _train_step_and_compare(
        self, fsdp_meta, meta_opt, fsdp_regular, regular_opt, inp
    ):
        """
        Run one forward + backward pass + optimizer step on both models with
        the same input, then assert that their parameters still match.
        """
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
        """
        Check a ``MyModel`` built by ``meta_module_fn`` against a regular
        CUDA ``MyModel``, first with every submodule wrapped and then with a
        single top-level FSDP unit.

        Args:
            meta_module_fn: Zero-argument callable returning the model to wrap.
            init_fn: Optional ``param_init_fn`` forwarded to FSDP.
        """
        # Create model on meta device and wrap with FSDP.
        model = meta_module_fn()
        is_meta = next(model.parameters()).is_meta
        fsdp_meta = FSDP(
            model,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Test to make sure it is the same model parameters as regular FSDP
        # approach.
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        self._train_step_and_compare(
            fsdp_meta, meta_opt, fsdp_regular, regular_opt, inp
        )

        # Test that meta init works if all submodules are contained in only a
        # single FSDP unit.
        model = meta_module_fn()
        fsdp_meta = FSDP(model, param_init_fn=init_fn)
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Run a forward + backward pass + optimizer step
        self._train_step_and_compare(
            fsdp_meta, meta_opt, fsdp_regular, regular_opt, inp
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_reset_params(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(
            meta_module_fn, _init_with_reset_params
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_default_init(self):
        def meta_module_fn():
            return MyModel(device="meta")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_default_init(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_simple_model_with_torchdistX_init_fn(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(
            meta_module_fn, init_fn=_init_with_torchdistX
        )

    def _test_nested_model_with_meta_device(
        self, auto_wrap, meta_module_fn, init_fn=None
    ):
        """
        Check a ``NestedModel`` built by ``meta_module_fn`` against a regular
        CUDA ``NestedModel``.

        Args:
            auto_wrap: If true, wrap via ``auto_wrap_policy``; otherwise rely
                on the ``wrap()`` calls made under ``enable_wrap``.
            meta_module_fn: Zero-argument callable returning the model to wrap.
            init_fn: Optional ``param_init_fn`` forwarded to FSDP.
        """
        if auto_wrap:
            module = meta_module_fn()
            is_meta = (
                next(module.parameters()).is_meta or next(module.buffers()).is_meta
            )
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP,
                param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non FSDP modules will still be initialized because they bubble up
                # to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Init and reset parameters before wrapping so that reset_params
            # matches up with meta device's initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare it before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        self._train_step_and_compare(
            fsdp_meta, meta_opt, fsdp_regular, regular_opt, inp
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_reset_params(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_reset_params,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_default_init(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_default_init(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn
        )

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_init_fn(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap,
            meta_module_fn=meta_module_fn,
            init_fn=_init_with_torchdistX,
        )

    def _test_bad_arg(self, meta_module_fn):
        """Assert that a non-callable ``param_init_fn`` raises ``ValueError``."""
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        not _TORCHDISTX_AVAIL,
        "Test requires torchdistX: https://github.com/pytorch/torchdistX",
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_meta_device_with_mixed_precision(self):
        """
        Tests meta device initialization with a ``param_init_fn`` when
        specifying mixed precision with ``param_dtype=torch.float32``.
        """

        class FakeLinear(nn.Module):
            def __init__(
                self, in_dim: int, out_dim: int, device: Union[torch.device, str]
            ) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.randn((in_dim, out_dim), device=device)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x @ self.weight

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin1 = nn.Linear(5, 5, device="meta")
                self.lin2 = FakeLinear(5, 5, device="meta")
                self.relu = nn.ReLU()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.lin2(self.relu(self.lin1(x)))

            def _module_init_fn(self, module: nn.Module):
                if isinstance(module, nn.Linear):
                    torch.nn.init.normal_(module.weight, mean=0.0, std=0.1)
                    if module.bias is not None:
                        torch.nn.init.zeros_(module.bias)

        def _param_init_fn(module: nn.Module) -> None:
            # TODO: `module.to_empty()` is not generally correct for meta
            # device initialization.
            # https://github.com/pytorch/pytorch/issues/90465
            module.to_empty(device=torch.device("cuda"))
            # `model` is bound below, before FSDP invokes this function.
            module.apply(model._module_init_fn)

        model = Model()
        # Wrap `lin1` and the top level `model` to create nested FSDP instances
        # where each instance has parameters
        FSDP(
            model,
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            mixed_precision=MixedPrecision(
                param_dtype=torch.float32, reduce_dtype=torch.float16
            ),
            param_init_fn=_param_init_fn,
            device_id=torch.cuda.current_device(),
        )
419

420

421
# Generate the concrete test methods for the ``@parametrize``-decorated tests
# of the class above.
instantiate_parametrized_tests(TestFSDPWithMetaDevice)

# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
425

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.