# pytorch — distributed checkpoint format-utils tests (copied from a code viewer; header residue removed)
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.checkpoint.format_utils import (
    BroadcastingTorchSaveReader,
    dcp_to_torch_save,
    DynamicMetaLoadPlanner,
    torch_save_to_dcp,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class SimpleModelUneven(nn.Module):
    """Small 4-layer MLP with deliberately uneven layer widths (5→10→15→30→5).

    The uneven widths exercise sharding paths where parameters do not divide
    evenly across ranks. The constructor seeds the RNG so every instantiation
    yields identical weights, letting tests compare state dicts across
    save/load round-trips.
    """

    def __init__(self) -> None:
        super().__init__()
        # Fixed seed: two fresh instances have identical parameters, so a
        # freshly constructed model can be compared against a loaded one.
        torch.manual_seed(0)
        self.net1 = nn.Linear(5, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 15)
        self.net3 = nn.Linear(15, 30)
        self.net4 = nn.Linear(30, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP: ReLU after each hidden layer, linear output layer."""
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = self.net4(x)
        return x

    def get_input(self) -> torch.Tensor:
        """Return a random (4, 5) input batch on CUDA for the GPU tests."""
        return torch.rand(4, 5, device="cuda")
46
class TestFormatUtils(DTensorTestBase):
    """Round-trip tests for DCP <-> torch.save checkpoint format conversion."""

    @with_temp_dir
    def test_dcp_to_torch_save(self) -> None:
        """Save a model with dcp.save, convert to torch.save format, and
        verify the converted file loads back to the original state dict."""
        model = SimpleModelUneven()
        dcp.save({"model": model}, checkpoint_id=self.temp_dir)

        torch_path = self.temp_dir + "/model.pt"
        dcp_to_torch_save(self.temp_dir, torch_path)

        loaded_sd = torch.load(torch_path)
        self.assertEqual(loaded_sd, {"model": model.state_dict()})

    @with_temp_dir
    def test_torch_save_to_dcp(self) -> None:
        """Save a state dict with torch.save, convert to DCP format, and
        verify dcp.load restores the original weights into a fresh model."""
        model = SimpleModelUneven()
        sd = {"model": model.state_dict()}
        torch_path = self.temp_dir + "/model.pt"
        torch.save(sd, torch_path)

        torch_save_to_dcp(torch_path, self.temp_dir)

        # Fresh instance is identical to the saved one (seeded __init__),
        # so loading into it and comparing against `sd` checks the round trip.
        model = SimpleModelUneven()
        dcp.load({"model": model}, checkpoint_id=self.temp_dir)

        self.assertEqual({"model": model.state_dict()}, sd)

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_online_torch_save_to_dcp(self) -> None:
        """Tests loading a model saved by torch.save directly into a sharded model
        using dcp.load
        """
        # Save a model with torch.save
        model = SimpleModelUneven()
        sd = {"model": model.state_dict()}

        torch_fn = self.temp_dir + "/model.pt"
        # Only rank 0 writes; the barrier keeps other ranks from reading
        # the file before it exists.
        if dist.get_rank() == 0:
            torch.save(sd, torch_fn)
        dist.barrier()

        # Load into a sharded model
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
        model = SimpleModelUneven().cuda()
        model = FSDP(
            model,
            device_mesh=device_mesh,
            use_orig_params=True,
        )
        # BroadcastingTorchSaveReader + DynamicMetaLoadPlanner let dcp.load
        # read a plain torch.save file straight into the sharded model.
        dcp.load(
            {"model": model},
            planner=DynamicMetaLoadPlanner(),
            storage_reader=BroadcastingTorchSaveReader(),
            checkpoint_id=torch_fn,
        )

        self.assertEqual(sd["model"], model.state_dict())
106
# Standard PyTorch test entry point: dispatch to the internal test runner.
if __name__ == "__main__":
    run_tests()