# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

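# Note: torch._C._cuda_clearCublasWorkspaces() used below is a private helper; it
# releases the cuBLAS workspace allocations so that torch.cuda.memory_allocated()
# is not skewed by them and the readings stay comparable across iterations.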

def get_cur_mem(rank, result, prefix):
    """Collect memory allocated values in a result dict in MB"""
    torch._C._cuda_clearCublasWorkspaces()
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)

class Model(nn.Module):
    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()
        if with_fsdp:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                FSDP(nn.BatchNorm2d(64)),
                nn.ReLU(inplace=True),
            )
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        if with_fsdp:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
        else:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        if self.with_checkpoint:
            return self.head(checkpoint(self.blocks, self.stem(x), use_reentrant=True))
        else:
            return self.head(self.blocks(self.stem(x)))

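# Each of stem/blocks/head is wrapped in its own FSDP instance below (with the
# BatchNorm layers above as further nested FSDP instances), so full parameters are
# all-gathered only while the corresponding section runs and are freed again
# afterwards; this is what keeps peak memory close to the sharded parameter size
# plus activations.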

def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model

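# The test below trains this model under FSDP for a few iterations and records
# torch.cuda.memory_allocated() (via get_cur_mem) after each phase of a step
# (forward, loss, backward, optimizer step, zero_grad), then compares the readings
# against expected values with a 1MB tolerance.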

class TestFSDPMemory(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        # Outer FSDP wrap on top of the already-wrapped stem/blocks/head submodules.
        model = FSDP(model)

        # We enable momentum so that after the first iteration, the optimizer state is
        # added to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # memory stats collected at each phase
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # It is important to use `set_to_none=True` here, rather than zeroing the
            # gradients in place, so that gradient memory is actually reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")

    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model

        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # We have observed that the 4th iteration can sometimes fail even when the
        # first 3 pass (not in this test, but in much larger-scale tests), so we run
        # 4 iterations here just in case.
        iterations = 4

        expected = {}
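        # Rough accounting behind the expected values below (with world_size=2 and the
        # ~4MB model above, sharded_model_size_mb is ~2MB):
        #   iter 0: start      = sharded params + ~1MB temp memory
        #           after fwd  = measured total (~51MB with checkpointing, ~340MB
        #                        without), dominated by activations; same after loss
        #           after bwd  = sharded params + sharded grads + ~1MB
        #           after step = plus sharded momentum buffers from SGD(momentum=0.9)
        #           done       = grads freed by zero_grad(set_to_none=True)
        #   iter 1+: start / after fwd / after loss / after bwd each carry one extra
        #            sharded_model_size_mb for the now-persistent optimizer state.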

        for iteration in range(iterations):
            if iteration == 0:
                # sharded model size + 1MB temp memory
                expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                # this memory usage is hard to compute analytically; the values below
                # were taken from observed memory usage
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = 51
                    expected[f"iter {iteration}: after loss"] = 51
                else:
                    expected[f"iter {iteration}: after fwd"] = 340
                    expected[f"iter {iteration}: after loss"] = 340
                # sharded model size + sharded grad size + 1MB temp memory
                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
            else:
                # after the optimizer step in the first iteration, memory usage grows by
                # sharded_model_size_mb because the optimizer state is now allocated
                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = (
                        51 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        51 + sharded_model_size_mb
                    )
                else:
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed once the grads are set to None
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Derive the checkpoint flag from the parametrization.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)

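# Running this file directly executes both parametrized variants ("no_ckpt" and
# "ckpt"); FSDPTest runs each test across world_size processes, so at least two GPUs
# are required (enforced by @skip_if_lt_x_gpu(2)).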

if __name__ == "__main__":
    run_tests()