# pytorch · 270 lines · 9.3 KB (source-listing metadata, translated from Russian)
# Owner(s): ["module: unknown"]
import functools
import gc
from typing import Union

import torch
import torch.nn as nn
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import (
    CPUOffloadPolicy,
    fully_shard,
    MixedPrecisionPolicy,
    OffloadPolicy,
)
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    CheckpointWrapper,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
    TransformerBlock,
)
30
31def _init_cublas_workspace(dev: torch.device):32lin = torch.nn.Linear(768, 768, device=dev)33inp = torch.randn(1, 768, device=dev)34lin(inp).sum().backward()35del lin36del inp37
38
def _reset_mem_stats(dev: torch.device):
    """Release cached allocator blocks, then zero the CUDA accumulated and
    peak memory counters for ``dev`` so subsequent stats start from a clean
    baseline.
    """
    torch.cuda.empty_cache()
    for reset_fn in (
        torch.cuda.reset_accumulated_memory_stats,
        torch.cuda.reset_peak_memory_stats,
    ):
        reset_fn(dev)
44
class TestTrackerFullyShard1DTrainingCore(FSDPTest):
    """Checks `FSDPMemTracker`'s peak-memory estimates against the CUDA
    allocator's reported peak for eager FSDP2 (`fully_shard`) training."""

    @property
    def world_size(self) -> int:
        # Cap at 4 ranks even on hosts with more GPUs.
        return min(4, torch.cuda.device_count())

    @skip_if_lt_x_gpu(2)
    def test_tracker_multi_group_eager(self):
        """
        Tests tracker accuracy when using multiple parameter groups for
        communication (for communication and computation overlap plus memory
        reduction) and different mixed precision policies.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "offload_policy": [
                    CPUOffloadPolicy(pin_memory=False),
                    OffloadPolicy(),
                ],
                "mp_policy": [
                    MixedPrecisionPolicy(
                        # fp16 params with fp32 gradient reduction
                        param_dtype=torch.float16, reduce_dtype=torch.float32
                    ),
                ],
            },
            self._test_tracker_multi_group,
        )

    def _test_tracker_multi_group(
        self,
        reshard_after_forward: Union[bool, int],
        offload_policy: OffloadPolicy,
        mp_policy: MixedPrecisionPolicy,
    ):
        """Run two train iterations under `FSDPMemTracker` and assert the
        tracked peak is within 10% of the CUDA allocator's measured peak.

        Args:
            reshard_after_forward: Forwarded to ``fully_shard``.
            offload_policy: CPU-offload vs. default (no offload) policy.
            mp_policy: Mixed-precision policy for params/reductions.
        """
        debug = False  # flip to True for per-rank accuracy printouts
        dev = torch.device(torch.cuda.current_device())
        # Trigger one-time allocations and reset counters so the baseline
        # below excludes warm-up memory.
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        # Baseline of currently-active bytes; subtracted from the peak later.
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        lin_dim, bsz = 2048, 8192
        with torch.device(dev):
            model = nn.Sequential(*[MLP(dim=lin_dim, device=dev) for _ in range(4)])
        mesh = init_device_mesh("cuda", (self.world_size,))
        fully_shard_fn = functools.partial(
            fully_shard,
            mesh=mesh,
            reshard_after_forward=reshard_after_forward,
            offload_policy=offload_policy,
            mp_policy=mp_policy,
        )
        # Shard each MLP separately (multiple parameter groups), then the root.
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        inp = torch.randn((bsz, lin_dim), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                loss = model(inp).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    # Discard first-iteration stats (one-time allocations).
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        # Tracker must agree with the allocator to within 10%.
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del model
        del inp
        del optim

    @skip_if_lt_x_gpu(2)
    def test_tracker_non_root_forward_backward(self):
        """
        Tests tracker accuracy when running forward/backward through a non-root.
        """
        debug = False  # flip to True for per-rank accuracy printouts
        dev = torch.device(torch.cuda.current_device())
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        lin_dim, bsz = 2048, 8
        model = nn.Sequential(*[MLP(lin_dim, dev) for _ in range(3)])
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        # Per-rank seed so each rank gets distinct inputs.
        torch.manual_seed(42 + self.rank)
        inp = torch.randn((bsz, lin_dim), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                # Forward/backward only through the first (non-root) module.
                nonroot_loss = model[0](inp).sum()
                nonroot_loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim
176
class TestTrackerFullyShard1DTrainingCompose(FSDPTest):
    """Checks `FSDPMemTracker` accuracy when FSDP2 is composed with
    activation checkpointing on a Transformer model."""

    @property
    def world_size(self) -> int:
        # Cap at 4 ranks even on hosts with more GPUs.
        return min(torch.cuda.device_count(), 4)

    @skip_if_lt_x_gpu(2)
    def test_tracker_with_activation_checkpointing(self):
        """
        Tests tracker accuracy when composing with activation checkpointing.
        """
        self.run_subtests(
            {
                "reshard_after_forward": [True, False],
                "checkpoint_impl": ["composable", "wrapper"],
            },
            self._test_tracker_with_activation_checkpointing,
        )

    def _test_tracker_with_activation_checkpointing(
        self, reshard_after_forward: Union[bool, int], checkpoint_impl: str
    ):
        """Run two train iterations of a checkpointed Transformer under
        `FSDPMemTracker` and assert the tracked peak is within 10% of the
        CUDA allocator's measured peak.

        Args:
            reshard_after_forward: Forwarded to ``fully_shard``.
            checkpoint_impl: ``"composable"`` (`checkpoint` API) or
                ``"wrapper"`` (`apply_activation_checkpointing`).
        """
        assert checkpoint_impl in ("composable", "wrapper")
        debug = False  # flip to True for per-rank accuracy printouts
        dev = torch.device(torch.cuda.current_device())
        # Trigger one-time allocations and reset counters so the baseline
        # below excludes warm-up memory.
        _init_cublas_workspace(dev)
        gc.collect()
        _reset_mem_stats(dev)
        mem_stats = torch.cuda.memory_stats(dev)
        # Baseline of currently-active bytes; subtracted from the peak later.
        pre_cuda_active = mem_stats["active_bytes.all.current"]
        torch.manual_seed(42)
        vocab_size = 8192
        bsz, seq_len = 16, 512
        with torch.device(dev):
            model_args = ModelArgs(
                n_layers=4,
                n_heads=4,
                vocab_size=vocab_size,
                max_seq_len=seq_len,
                dropout_p=0.1,
            )
            model = Transformer(model_args)
        foreach = False
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
        )
        if checkpoint_impl == "wrapper":
            apply_activation_checkpointing(
                model, check_fn=lambda m: isinstance(m, TransformerBlock)
            )
            for module in model.modules():
                # Apply to `CheckpointWrapper`, which wraps `TransformerBlock`
                if isinstance(module, CheckpointWrapper):
                    fully_shard_fn(module)
        else:
            for module in model.modules():
                if isinstance(module, TransformerBlock):
                    if checkpoint_impl == "composable":
                        checkpoint(module)
                    fully_shard_fn(module)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

        # Per-rank seed so each rank gets distinct token inputs.
        torch.manual_seed(42 + self.rank)
        inp = torch.randint(0, vocab_size, (bsz, seq_len), device=dev)
        fmt = FSDPMemTracker(model, optim)
        fmt.track_inputs((inp,))
        with fmt:
            for iter_idx in range(2):
                loss = model(inp).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()
                if iter_idx == 0:
                    # Discard first-iteration stats (one-time allocations).
                    fmt.reset_mod_stats()
        mem_stats = torch.cuda.memory_stats()
        tracker_max = fmt.get_tracker_snapshot("peak")[dev]["Total"]
        cuda_max = mem_stats["active_bytes.all.peak"] - pre_cuda_active
        accuracy = tracker_max / cuda_max
        if self.rank == 0 and debug:
            print(f"Accuracy: {accuracy} Tracker Max:{tracker_max} CUDA Max:{cuda_max}")
        # Tracker must agree with the allocator to within 10%.
        self.assertAlmostEqual(
            accuracy,
            1.0,
            delta=0.1,
            msg=f"Tracker Max:{tracker_max} CUDA Max:{cuda_max}",
        )
        del inp
        del model
        del optim
268
if __name__ == "__main__":
    # Dispatch to PyTorch's common test runner when executed as a script.
    run_tests()