# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)
from torch.utils.checkpoint import checkpoint

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

def get_cur_mem(rank, result, prefix):
    """Record the current CUDA memory allocated, in MB, into ``result`` under ``prefix``."""
    # Free cuBLAS workspaces first so they do not skew the reading.
    torch._C._cuda_clearCublasWorkspaces()
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)


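# Note: in the `with_fsdp` branches below, each BatchNorm2d gets its own FSDP
# wrapper while the convolutions are left to the enclosing FSDP unit, presumably
# so the small BatchNorm parameters live in separate flat parameters. The test
# only depends on the resulting sharded sizes, not on this wrapping policy.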
class Model(nn.Module):
    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()
        if with_fsdp:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                FSDP(nn.BatchNorm2d(64)),
                nn.ReLU(inplace=True),
            )
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        if with_fsdp:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
        else:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        if self.with_checkpoint:
            return self.head(checkpoint(self.blocks, self.stem(x), use_reentrant=True))
        else:
            return self.head(self.blocks(self.stem(x)))


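# create_model wraps each top-level submodule (stem / blocks / head) in its own
# FSDP unit when `with_fsdp` is set; `_dist_train` below additionally wraps the
# whole model in a root FSDP instance, yielding a nested-wrapping setup.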
def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model


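# The test trains for a few iterations on 2 ranks and snapshots
# torch.cuda.memory_allocated() at fixed points in each iteration, comparing the
# readings (to within 1MB) against hand-derived expectations built from the
# sharded parameter, gradient, and optimizer-state sizes.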
class TestFSDPMemory(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        model = FSDP(model)

        # Use momentum so that after the first iteration the optimizer state is
        # added to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # memory stats collected at fixed points in each iteration
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            # Reduce the first sample's logits to a scalar and compare against a
            # zero target to get a fake loss for backward.
            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # It is important to use `set_to_none=True` here, not
            # optimizer.zero_grad(), so that gradient memory is actually reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")

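    # Summary of the expectations checked below (values in MB; P = sharded model
    # size on each rank, plus ~1MB of temporary memory throughout):
    #   iter 0 start:   P + 1    sharded params only
    #   after bwd:      + P      sharded grads materialize
    #   after step:     3P + 1   params + grads + SGD momentum buffers
    #   done:           2P + 1   grads freed via zero_grad(set_to_none=True)
    #   later starts:   2P + 1   momentum buffers persist across iterations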
    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model

        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # We have observed that the 4th iteration can sometimes fail even when the
        # first 3 pass (not in this test, but in much larger-scale tests), so run
        # 4 iterations here just in case.
        iterations = 4

        expected = {}

        for iteration in range(iterations):
            if iteration == 0:
                # sharded model size + 1MB temp memory
                expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                # These sizes are hard to derive analytically; the values were
                # taken from observed memory usage.
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = 51
                    expected[f"iter {iteration}: after loss"] = 51
                else:
                    expected[f"iter {iteration}: after fwd"] = 340
                    expected[f"iter {iteration}: after loss"] = 340
                # sharded model size + sharded grad size + 1MB temp memory
                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
            else:
                # After the optimizer step in the first iteration, memory usage grows
                # by sharded_model_size_mb due to the added optimizer state (momentum).
                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = (
                        51 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        51 + sharded_model_size_mb
                    )
                else:
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed after setting grads to None
            # sharded model size + optimizer states + 1MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Derive the checkpoint flag from the parametrized value.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)


if __name__ == "__main__":
    run_tests()
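# Note: running this file directly requires at least 2 CUDA GPUs (the test is
# skipped otherwise). Under pytest, the parametrization above should expose the
# variants as test_fsdp_memory_no_ckpt and test_fsdp_memory_ckpt; those exact
# ids are an assumption about instantiate_parametrized_tests' naming, not
# something this file asserts.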