pytorch

Форк
0
/
test_format_utils.py 
108 строк · 3.2 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import torch
4
import torch.distributed as dist
5
import torch.distributed.checkpoint as dcp
6
import torch.nn as nn
7

8
import torch.nn.functional as F
9
from torch.distributed._tensor.device_mesh import init_device_mesh
10
from torch.distributed.checkpoint.format_utils import (
11
    BroadcastingTorchSaveReader,
12
    dcp_to_torch_save,
13
    DynamicMetaLoadPlanner,
14
    torch_save_to_dcp,
15
)
16
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
17
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
18
from torch.testing._internal.common_utils import run_tests
19
from torch.testing._internal.distributed._tensor.common_dtensor import (
20
    DTensorTestBase,
21
    with_comms,
22
)
23
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
24

25

26
class SimpleModelUneven(nn.Module):
27
    def __init__(self):
28
        super().__init__()
29
        torch.manual_seed(0)
30
        self.net1 = nn.Linear(5, 10)
31
        self.relu = nn.ReLU()
32
        self.net2 = nn.Linear(10, 15)
33
        self.net3 = nn.Linear(15, 30)
34
        self.net4 = nn.Linear(30, 5)
35

36
    def forward(self, x):
37
        x = F.relu(self.net1(x))
38
        x = F.relu(self.net2(x))
39
        x = F.relu(self.net3(x))
40
        x = self.net4(x)
41
        return x
42

43
    def get_input(self):
44
        return torch.rand(4, 5, device="cuda")
45

46

47
class TestFormatUtils(DTensorTestBase):
48
    @with_temp_dir
49
    def test_dcp_to_torch_save(self) -> None:
50
        model = SimpleModelUneven()
51
        dcp.save({"model": model}, checkpoint_id=self.temp_dir)
52

53
        torch_path = self.temp_dir + "/model.pt"
54
        dcp_to_torch_save(self.temp_dir, torch_path)
55

56
        loaded_sd = torch.load(torch_path)
57
        self.assertEqual(loaded_sd, {"model": model.state_dict()})
58

59
    @with_temp_dir
60
    def test_torch_save_to_dcp(self) -> None:
61
        model = SimpleModelUneven()
62
        sd = {"model": model.state_dict()}
63
        torch_path = self.temp_dir + "/model.pt"
64
        torch.save(sd, torch_path)
65

66
        torch_save_to_dcp(torch_path, self.temp_dir)
67

68
        model = SimpleModelUneven()
69
        dcp.load({"model": model}, checkpoint_id=self.temp_dir)
70

71
        self.assertEqual({"model": model.state_dict()}, sd)
72

73
    @with_comms
74
    @with_temp_dir
75
    @skip_if_lt_x_gpu(2)
76
    def test_online_torch_save_to_dcp(self) -> None:
77
        """Tests loading a model saved by torch.save directly into a sharded model
78
        using dcp.load
79
        """
80
        # Save a model with torch.save
81
        model = SimpleModelUneven()
82
        sd = {"model": model.state_dict()}
83

84
        torch_fn = self.temp_dir + "/model.pt"
85
        if dist.get_rank() == 0:
86
            torch.save(sd, torch_fn)
87
        dist.barrier()
88

89
        # Load into a sharded model
90
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
91
        model = SimpleModelUneven().cuda()
92
        model = FSDP(
93
            model,
94
            device_mesh=device_mesh,
95
            use_orig_params=True,
96
        )
97
        dcp.load(
98
            {"model": model},
99
            planner=DynamicMetaLoadPlanner(),
100
            storage_reader=BroadcastingTorchSaveReader(),
101
            checkpoint_id=torch_fn,
102
        )
103

104
        self.assertEqual(sd["model"], model.state_dict())
105

106

107
if __name__ == "__main__":
108
    run_tests()
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.