# Owner(s): ["module: unknown"]
import functools
import gc
from typing import Union

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
    OffloadPolicy,
)
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    CheckpointWrapper,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)


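# Run a small linear forward/backward once so the cuBLAS workspace is allocated
# up front and is not attributed to the memory measured in the tests below.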
def _init_cublas_workspace(dev: torch.device):
    lin = torch.nn.Linear(768, 768, device=dev)
    inp = torch.randn(1, 768, device=dev)
    lin(inp).sum().backward()
    del lin
    del inp


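# Release cached allocator blocks and reset accumulated/peak statistics so each
# test starts measuring from a clean baseline.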
def _reset_mem_stats(dev: torch.device):
    torch.cuda.empty_cache()
    torch.cuda.reset_accumulated_memory_stats(dev)
    torch.cuda.reset_peak_memory_stats(dev)


class TestTrackerFullyShard1DTrainingCore(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_tracker_multi_group_eager(self):
        """
        Tests tracker accuracy when using multiple parameter groups for
        communication (for communication and computation overlap plus memory
        reduction) and different mixed precision policies.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "offload_policy": [
                    CPUOffloadPolicy(pin_memory=False),
                    OffloadPolicy(),
                ],
                "mp_policy": [
                    MixedPrecisionPolicy(
                        param_dtype=torch.float16, reduce_dtype=torch.float32
                    ),
                ],
            },
            self._test_tracker_multi_group,
        )

    def _test_tracker_multi_group(
        self,
        reshard_after_forward: Union[bool, int],
        offload_policy: OffloadPolicy,
        mp_policy: MixedPrecisionPolicy,
    ):
        debug = False
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
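        # Record the currently active bytes so pre-existing allocations can be
        # subtracted from the CUDA peak measured after the training loop.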
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        lin_dim, bsz = 2048, 8192
        with torch.device(dev):
            model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)])
        mesh = init_device_mesh("cuda", (self.world_size,))
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
            mp_policy=mp_policy,
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        inp = torch.randn((bsz, lin_dim), device=dev)
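        # FSDPMemTracker records per-module memory usage while the iterations
        # below run inside its context; inputs are registered so their memory
        # is attributed as well.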
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                loss = model(inp).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
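        # The tracker's peak estimate should match the CUDA allocator's measured
        # peak (net of the pre-existing baseline) to within 10%.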
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del model
        del inp
        del optim

    @skip_if_lt_x_gpu(2)
    def test_tracker_non_root_forward_backward(self):
        """
        Tests tracker accuracy when running forward/backward through a non-root.
        """
        debug = False
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        lin_dim, bsz = 2048, 8
        model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)])
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((bsz, lin_dim), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                nonroot_loss = model[0](inp).sum()
                nonroot_loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim


class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 4)

    @skip_if_lt_x_gpu(2)
    def test_tracker_with_activation_checkpointing(self):
        """
        Tests tracker accuracy when composing with activation checkpointing.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "checkpoint_impl": ["composable", "wrapper"],
            },
            self._test_tracker_with_activation_checkpointing,
        )

    def _test_tracker_with_activation_checkpointing(
        self, reshard_after_forward: Union[bool, int], checkpoint_impl: str
    ):
        assert checkpoint_impl in ("composable", "wrapper")
        debug = False
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        vocab_size = 8192
        bsz, seq_len = 16, 512
        with torch.device(dev):
            model_args = ModelArgs(
                n_layers=4,
                n_heads=4,
                vocab_size=vocab_size,
                max_seq_len=seq_len,
                dropout_p=0.1,
            )
            model = Transformer(model_args)
        foreach = False
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
        )
        if checkpoint_impl == "wrapper":
            apply_activation_checkpointing(
                model, check_fn=lambda m: isinstance(m, TransformerBlock)
            )
            for module in model.modules():
                # Apply to `CheckpointWrapper`, which wraps `TransformerBlock`
                if isinstance(module, CheckpointWrapper):
                    fully_shard_fn(module)
        else:
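            # Composable checkpointing is applied directly to each
            # `TransformerBlock` before it is sharded.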
            for module in model.modules():
                if isinstance(module, TransformerBlock):
                    if checkpoint_impl == "composable":
                        checkpoint(module)
                    fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, vocab_size, (bsz, seq_len), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                loss = model(inp).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim


if __name__ == "__main__":
    run_tests()