# pytorch
# 236 lines · 9.2 KB
1# Owner(s): ["module: unknown"]
2import gc3import unittest4from typing import Tuple5
6import torch7import torch.nn as nn8from torch.distributed._tools.mem_tracker import MemTracker9from torch.testing._internal.common_cuda import TEST_CUDA10from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase11from torch.utils.checkpoint import checkpoint12
13
class TestMemTracker(TestCase):
    """Tests for ``MemTracker``.

    Checks that the tracker's peak-memory numbers agree with the CUDA
    caching allocator's (with and without activation checkpointing), and
    that memory is attributed to the right categories (parameters,
    gradients, optimizer state).
    """

    def _init_cublas_workspace(self, dev: torch.device) -> None:
        # Run one small linear forward+backward before any measurement is
        # taken — presumably so cuBLAS allocates its workspace up front and
        # that one-time allocation does not skew the tests' baselines
        # (NOTE(review): inferred from the helper's name and its use before
        # _reset_mem_stats; confirm against MemTracker docs).
        lin = torch.nn.Linear(768, 768, device=dev)
        inp = torch.randn(1, 768, device=dev)
        lin(inp).sum().backward()
        del lin
        del inp

    def _reset_mem_stats(self, dev: torch.device) -> None:
        # Return the CUDA caching allocator and its accumulated/peak
        # counters to a clean slate so each test measures only its own
        # allocations.
        torch.cuda.empty_cache()
        torch.cuda.reset_accumulated_memory_stats(dev)
        torch.cuda.reset_peak_memory_stats(dev)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_cuda_tracker_equivalence(
        self,
    ):
        """
        Tests that the tracker correctly calculates the peak memory.
        """
        dev = torch.device(torch.cuda.current_device())
        self._init_cublas_workspace(dev)
        gc.collect(1)
        self._reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        # Whatever is still allocated at this point (e.g. the cuBLAS
        # workspace) is subtracted from the CUDA peak reading below.
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        bsz, n_layers, dim, dtype = 16, 4, 512, torch.bfloat16

        class DummyModel(nn.Module):
            # Plain stack of constant-width Linear + ReLU layers.
            def __init__(self, n_layers: int, dim: int, dtype: torch.dtype):
                super().__init__()
                self.linears = nn.ModuleList()
                for _ in range(n_layers):
                    self.linears.append(nn.Linear(dim, dim, dtype=dtype))
                    self.linears.append(nn.ReLU())

            def forward(self, x):
                for layer in self.linears:
                    x = layer(x)
                return x

        with torch.device(dev):
            model = DummyModel(n_layers, dim, dtype=dtype)
        optim = torch.optim.Adam(model.parameters(), foreach=True)
        input_batch = torch.randn(bsz, dim, device=dev, dtype=dtype)
        mem_tracker = MemTracker()
        mem_tracker.track_external(model, optim, input_batch)
        with mem_tracker as mt:
            # Two iterations: the second one runs with optimizer state
            # already allocated; per-module stats are reset in between so
            # only the steady-state iteration is recorded.
            for iter_idx in range(2):
                model(input_batch).sum().backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    mt.reset_mod_stats()
        # Check for accuracy of peak memory

        tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
        mem_stats = torch.cuda.memory_stats(dev)
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        # Tracker peak must be within 10% of the allocator-reported peak.
        self.assertAlmostEqual(accuracy, 1.0, delta=0.1)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_tracker_with_activation_checkpointing(
        self,
    ):
        """
        Tests that the tracker correctly computes the peak memory during activation checkpointing.
        """
        dev = torch.device(torch.cuda.current_device())
        self._init_cublas_workspace(dev)
        gc.collect(1)
        self._reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        # Baseline to subtract from the allocator's peak reading below.
        pre_cuda_active = mem_stats["active_bytes.all.current"]

        bsz, n_layers, dim, dtype = 128, 4, 1024, torch.float16

        class MLPBlock(nn.Module):
            # Two-layer MLP with a 2x hidden expansion.
            def __init__(self, dim: int, dtype: torch.dtype):
                super().__init__()
                self.mlp_block = nn.Sequential(
                    nn.Linear(dim, 2 * dim, dtype=dtype),
                    nn.ReLU(),
                    nn.Linear(2 * dim, dim, dtype=dtype),
                )

            def forward(self, x):
                return self.mlp_block(x)

        class MyModule(nn.Module):
            def __init__(
                self, n_layers: int, dim: int, dtype: torch.dtype, use_ac: bool = False
            ):
                super().__init__()
                self.mlp_blocks = nn.ModuleList()
                self.use_ac = use_ac
                for _ in range(n_layers):
                    self.mlp_blocks.append(MLPBlock(dim, dtype=dtype))

            def forward(self, x):
                # When use_ac is set, every block except the first runs
                # under non-reentrant activation checkpointing.
                for i, block in enumerate(self.mlp_blocks):
                    if i >= 1 and self.use_ac:
                        x = checkpoint(
                            block, x, preserve_rng_state=True, use_reentrant=False
                        )
                    else:
                        x = block(x)
                return x

        with torch.device(dev):
            model = MyModule(n_layers, dim, dtype, True)
        optim = torch.optim.Adam(model.parameters(), foreach=True)
        mem_tracker = MemTracker()
        mem_tracker.track_external(model, optim)
        with mem_tracker as mt:
            # Unlike the previous test, the input is created inside the
            # tracker context, so the tracker accounts for it directly
            # rather than via track_external.
            input_batch = torch.randn(bsz, dim, dim, device=dev, dtype=dtype)
            for iter_idx in range(2):
                model(input_batch).sum().backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    mt.reset_mod_stats()

        # Check for accuracy of peak memory
        tracker_max = mt.get_tracker_snapshot("peak")[dev]["Total"]
        mem_stats = torch.cuda.memory_stats(dev)
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        # Tracker peak must be within 10% of the allocator-reported peak.
        self.assertAlmostEqual(accuracy, 1.0, delta=0.1)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    def test_tracker_attribution(self):
        """
        Tests that the tracker correctly categorizes params, gradients, and optimizer states.
        """
        # Runs on the default device (no CUDA requirement).
        dev = torch.device(torch.get_default_device())
        gc.collect(1)
        bsz, n_layers, dim, dtype = 16, 3, 128, torch.float32

        def get_param_grad_optstate_actual_bytes(
            model: nn.Module, opt: torch.optim.Optimizer
        ) -> Tuple[int, int, int]:
            # Ground truth: sum the byte sizes of all parameters, gradients,
            # and optimizer-state tensors that live on `dev`.
            param_bytes = 0
            grad_bytes = 0
            opt_state_bytes = 0
            for param in model.parameters():
                if param.device == dev:
                    param_bytes += param.numel() * param.element_size()
                if param.grad is not None and param.grad.device == dev:
                    grad_bytes += param.grad.numel() * param.grad.element_size()

            for state in opt.state.values():
                for v in state.values():
                    if isinstance(v, torch.Tensor) and v.device == dev:
                        opt_state_bytes += v.numel() * v.element_size()
            return param_bytes, grad_bytes, opt_state_bytes

        def get_param_grad_optstate_bytes_from_tracker(
            tracker: MemTracker,
        ) -> Tuple[int, int, int]:
            # Tracker's view of the same three categories, read from a
            # current snapshot.
            snapshot = tracker.get_tracker_snapshot()
            param_bytes = snapshot[dev]["Parameter"]
            grad_bytes = snapshot[dev]["Gradient"]
            opt_state_bytes = snapshot[dev]["Optstate"]
            return param_bytes, grad_bytes, opt_state_bytes

        def test_attribution_equivalence(
            mt: MemTracker,
            model: nn.Module,
            opt: torch.optim.Optimizer,
        ) -> None:
            # Each tracker category must be 0 exactly when the actual byte
            # count is 0, and otherwise within 10% of it.
            actual = get_param_grad_optstate_actual_bytes(model, opt)
            tracker = get_param_grad_optstate_bytes_from_tracker(mt)
            for a, b in zip(actual, tracker):
                if a == 0:
                    self.assertEqual(b, 0)
                else:
                    self.assertAlmostEqual(b / a, 1.0, delta=0.1)

        class DummyModel(nn.Module):
            # Stack of expand/contract Linear + GELU pairs.
            def __init__(self, n_layers: int, dim: int, dtype: torch.dtype):
                super().__init__()
                self.MLP_layers = nn.ModuleList()
                for _ in range(n_layers):
                    self.MLP_layers.extend(
                        [nn.Linear(dim, 2 * dim, dtype=dtype), nn.GELU()]
                    )
                    self.MLP_layers.extend(
                        [nn.Linear(2 * dim, dim, dtype=dtype), nn.GELU()]
                    )

            def forward(self, x):
                for layer in self.MLP_layers:
                    x = layer(x)
                return x

        with torch.device(dev):
            model = DummyModel(n_layers, dim, dtype=dtype)
        optim = torch.optim.Adam(model.parameters(), foreach=True)
        mem_tracker = MemTracker()
        mem_tracker.track_external(model, optim)
        with mem_tracker as mt:
            input_batch = torch.randn(bsz, dim, device=dev, dtype=dtype)
            # Before forward: Only parameters and input are allocated
            test_attribution_equivalence(mt, model, optim)
            output = model(input_batch)
            output.sum().backward()
            # After backward: Gradients are allocated
            test_attribution_equivalence(mt, model, optim)
            # Drop the output before step so activations don't linger.
            output = None
            optim.step()
            # After step: Optimizer state is allocated
            test_attribution_equivalence(mt, model, optim)
            optim.zero_grad()
            # After zero_grad: Gradients are deallocated
            test_attribution_equivalence(mt, model, optim)
234
if __name__ == "__main__":
    # Allow running this test file directly as a script.
    run_tests()