# Owner(s): ["module: unknown"]
import gc
import unittest
from typing import Tuple

import torch
import torch.nn as nn
from torch.distributed._tools.mem_tracker import MemTracker
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.utils.checkpoint import checkpoint


class TestMemTracker(TestCase):
    def _init_cublas_workspace(self, dev: torch.device):
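        # Warm-up: run one tiny linear forward/backward so one-time allocations
        # (e.g. the cuBLAS workspace) happen before tracking starts and are not
        # attributed to the models measured in the tests below.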
        lin = torch.nn.Linear(768, 768, device=dev)
        inp = torch.randn(1, 768, device=dev)
        lin(inp).sum().backward()
        del lin
        del inp

    def _reset_mem_stats(self, dev: torch.device):
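        # Release cached blocks and reset CUDA memory counters so each test
        # starts from a clean allocator baseline.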
        torch.cuda.empty_cache()
        torch.cuda.reset_accumulated_memory_stats(dev)
        torch.cuda.reset_peak_memory_stats(dev)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_cuda_tracker_equivalence(
        self,
    ):
        """
        Tests that the tracker correctly calculates the peak memory.
        """
        dev = torch.device(torch.cuda.current_device())
        self._init_cublas_workspace(dev)
        gc.collect(1)
        self._reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
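        # Bytes already active before tracking; subtracted from the CUDA peak
        # below so only memory allocated during the tracked iterations is compared.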
        bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16

        class DummyModel(nn.Module):
            def __init__(self, n_layers: int, dim: int, dtype: torch.dtype):
                super().__init__()
                self.linears = nn.ModuleList()
                for _ in range(n_layers):
                    self.linears.append(nn.Linear(dim, dim, dtype=dtype))
                    self.linears.append(nn.ReLU())

            def forward(self, x):
                for layer in self.linears:
                    x = layer(x)
                return x

        with torch.device(dev):
            model = DummyModel(n_layers, dim, dtype=dtype)
        optim = torch.optim.Adam(model.parameters(), foreach=True)
        input_batch = torch.randn(bsz, dim, device=dev, dtype=dtype)
        mem_tracker = MemTracker()
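        # track_external registers objects created outside the tracker context
        # (here the model, optimizer, and input batch) so their memory can be
        # attributed to the right categories.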
61
        mem_tracker.track_external(model, optim, input_batch)
62
        with mem_tracker as mt:
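            # Run two iterations and discard the per-module stats from the
            # first (warm-up) iteration.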
            for iter_idx in range(2):
                model(input_batch).sum().backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    mt.reset_mod_stats()

        # Check for accuracy of peak memory
        tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
        mem_stats = torch.cuda.memory_stats(dev)
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        self.assertAlmostEqual(accuracy, 1.0, delta=0.1)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_tracker_with_activation_checkpointing(
        self,
    ):
        """
        Tests that the tracker correctly computes the peak memory during activation checkpointing.
        """
        dev = torch.device(torch.cuda.current_device())
        self._init_cublas_workspace(dev)
        gc.collect(1)
        self._reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]

        bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16

        class MLPBlock(nn.Module):
            def __init__(self, dim: int, dtype: torch.dtype):
                super().__init__()
                self.mlp_block = nn.Sequential(
                    nn.Linear(dim, 2 * dim, dtype=dtype),
                    nn.ReLU(),
                    nn.Linear(2 * dim, dim, dtype=dtype),
                )

            def forward(self, x):
                return self.mlp_block(x)

        class MyModule(nn.Module):
            def __init__(
                self, n_layers: int, dim: int, dtype: torch.dtype, use_ac: bool = False
            ):
                super().__init__()
                self.mlp_blocks = nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.mlp_blocks.append(MLPBlock(dim, dtype=dtype))

            def forward(self, x):
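                # All blocks except the first run under activation checkpointing:
                # their intermediate activations are recomputed during backward
                # instead of being kept alive, which lowers the peak memory the
                # tracker has to account for.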
                for i, block in enumerate(self.mlp_blocks):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                return x

        with torch.device(dev):
            model = MyModule(n_layers, dim, dtype, True)
        optim = torch.optim.Adam(model.parameters(), foreach=True)
        mem_tracker = MemTracker()
        mem_tracker.track_external(model, optim)
        with mem_tracker as mt:
            input_batch = torch.randn(bsz, dim, dim, device=dev, dtype=dtype)
            for iter_idx in range(2):
                model(input_batch).sum().backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    mt.reset_mod_stats()

        # Check for accuracy of peak memory
        tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
        mem_stats = torch.cuda.memory_stats(dev)
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        self.assertAlmostEqual(accuracy, 1.0, delta=0.1)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    def test_tracker_attribution(self):
        """
        Tests that the tracker correctly categorizes params, gradients, and optimizer states.
        """
        dev = torch.device(torch.get_default_device())
        gc.collect(1)
        bsz, n_layers, dim, dtype = 16, 3, 128, torch.float32

        def get_param_grad_optstate_actual_bytes(
            model: nn.Module, opt: torch.optim.Optimizer
        ) -> Tuple[int, int, int]:
            param_bytes = 0
            grad_bytes = 0
            opt_state_bytes = 0
            for param in model.parameters():
                if param.device == dev:
                    param_bytes += param.numel() * param.element_size()
                if param.grad is not None and param.grad.device == dev:
                    grad_bytes += param.grad.numel() * param.grad.element_size()

            for state in opt.state.values():
                for v in state.values():
                    if isinstance(v, torch.Tensor) and v.device == dev:
                        opt_state_bytes += v.numel() * v.element_size()
            return param_bytes, grad_bytes, opt_state_bytes

        def get_param_grad_optstate_bytes_from_tracker(
            tracker: MemTracker,
        ) -> Tuple[int, int, int]:
            snapshot = tracker.get_tracker_snapshot()
            param_bytes = snapshot[dev]["Parameter"]
            grad_bytes = snapshot[dev]["Gradient"]
            opt_state_bytes = snapshot[dev]["Optstate"]
            return param_bytes, grad_bytes, opt_state_bytes

        def test_attribution_equivalence(
            mt: MemTracker,
            model: nn.Module,
            opt: torch.optim.Optimizer,
        ) -> None:
            actual = get_param_grad_optstate_actual_bytes(model, opt)
            tracker = get_param_grad_optstate_bytes_from_tracker(mt)
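            # Compare (param, grad, optimizer-state) byte counts pairwise; the
            # tracker's numbers should be within 10% of the directly computed sizes.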
            for a, b in zip(actual, tracker):
                if a == 0:
                    self.assertEqual(b, 0)
                else:
                    self.assertAlmostEqual(b / a, 1.0, delta=0.1)

        class DummyModel(nn.Module):
            def __init__(self, n_layers: int, dim: int, dtype: torch.dtype):
                super().__init__()
                self.MLP_layers = nn.ModuleList()
                for _ in range(n_layers):
                    self.MLP_layers.extend(
                        [nn.Linear(dim, 2 * dim, dtype=dtype), nn.GELU()]
                    )
                    self.MLP_layers.extend(
                        [nn.Linear(2 * dim, dim, dtype=dtype), nn.GELU()]
                    )

            def forward(self, x):
                for layer in self.MLP_layers:
                    x = layer(x)
                return x

        with torch.device(dev):
            model = DummyModel(n_layers, dim, dtype=dtype)
        optim = torch.optim.Adam(model.parameters(), foreach=True)
        mem_tracker = MemTracker()
        mem_tracker.track_external(model, optim)
        with mem_tracker as mt:
            input_batch = torch.randn(bsz, dim, device=dev, dtype=dtype)
            # Before forward: Only parameters and input are allocated
            test_attribution_equivalence(mt, model, optim)
            output = model(input_batch)
            output.sum().backward()
            # After backward: Gradients are allocated
            test_attribution_equivalence(mt, model, optim)
            output = None
            optim.step()
            # After step: Optimizer state is allocated
            test_attribution_equivalence(mt, model, optim)
            optim.zero_grad()
            # After zero_grad: Gradients are deallocated
            test_attribution_equivalence(mt, model, optim)


if __name__ == "__main__":
    run_tests()