pytorch

Форк
0
/
test_fsdp_optim_state.py 
138 строк · 4.7 Кб
1
# Owner(s): ["oncall: distributed"]
2

3
import torch
4

5
import torch.distributed.checkpoint as DCP
6
import torch.nn as nn
7
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
8
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
9

10
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
11
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
12
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
13
from torch.testing._internal.common_utils import (
14
    instantiate_parametrized_tests,
15
    parametrize,
16
    run_tests,
17
)
18

19
from torch.testing._internal.distributed._tensor.common_dtensor import (
20
    DTensorTestBase,
21
    with_comms,
22
)
23
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
24

25

26
class FsdpOptimStateCheckpoint(DTensorTestBase):
27
    def _create_model(self):
28
        # make weight tensor dim_0 as large as the world size for scaling test
29
        layer1_weight_dim = self.world_size
30
        layer2_weight_dim = self.world_size * 2
31
        layer3_weight_dim = self.world_size * 3
32

33
        class TestDummyModel(torch.nn.Module):
34
            def __init__(self):
35
                super().__init__()
36
                self.net1 = nn.Sequential(nn.Linear(8, layer1_weight_dim), nn.ReLU())
37
                self.net2 = nn.Sequential(
38
                    nn.Linear(layer1_weight_dim, layer2_weight_dim), nn.ReLU()
39
                )
40
                self.net3 = nn.Sequential(
41
                    nn.Linear(layer2_weight_dim, layer3_weight_dim), nn.ReLU()
42
                )
43

44
            def forward(self, x):
45
                return self.net3(self.net2(self.net1(x)))
46

47
            def get_input(self):
48
                return torch.rand(8, 8, device="cuda")
49

50
        model = TestDummyModel().cuda()
51
        return model
52

53
    @property
54
    def backend(self):
55
        return "cpu:gloo,cuda:nccl"
56

57
    @with_comms
58
    @skip_if_lt_x_gpu(2)
59
    @with_temp_dir
60
    @parametrize("pass_planner", [True, False])
61
    def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
62
        CHECKPOINT_DIR = self.temp_dir
63
        planner = DCP.DefaultLoadPlanner() if pass_planner else None
64

65
        model = self._create_model()
66
        model = FSDP(model)
67
        optim = torch.optim.Adam(model.parameters(), lr=0.1)
68

69
        # step ahead to initialize the optimizer
70
        model(model.get_input()).sum().backward()
71
        optim.step()
72

73
        FSDP.set_state_dict_type(
74
            model,
75
            StateDictType.SHARDED_STATE_DICT,
76
        )
77
        optim_osd = FSDP.optim_state_dict(model, optim)
78

79
        state_dict = {
80
            "model": model.state_dict(),
81
            "optim": optim_osd,
82
        }
83
        DCP.save_state_dict(
84
            state_dict=state_dict,
85
            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
86
        )
87

88
        # now load the model and ensure the values are the same
89
        model_2 = self._create_model()
90
        model_2 = FSDP(model_2)
91
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)
92

93
        FSDP.set_state_dict_type(
94
            model_2,
95
            StateDictType.SHARDED_STATE_DICT,
96
        )
97
        # Adam lazily creates its state
98
        self.assertEqual(0, len(optim_2.state))
99

100
        state_dict = {
101
            "model": model_2.state_dict(),
102
            # cannot load the optimizer together with the model
103
        }
104
        DCP.load_state_dict(
105
            state_dict=state_dict,
106
            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
107
        )
108
        model_2.load_state_dict(state_dict["model"])
109

110
        optim_state = load_sharded_optimizer_state_dict(
111
            model_state_dict=state_dict["model"],
112
            optimizer_key="optim",
113
            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
114
            planner=planner,
115
        )
116
        flattened_osd = FSDP.optim_state_dict_to_load(
117
            model_2, optim_2, optim_state["optim"]
118
        )
119
        optim_2.load_state_dict(flattened_osd)
120
        osd_after_load = FSDP.optim_state_dict(model_2, optim_2)
121

122
        # Compare optim_state_dict prior to save and after load
123
        before_optim_state = optim_osd["state"]
124
        after_optim_state = osd_after_load["state"]
125
        self.assertEqual(len(before_optim_state), len(after_optim_state))
126
        for fqn, states in before_optim_state.items():
127
            for state_name, state in states.items():
128
                state2 = after_optim_state.get(fqn).get(state_name)
129
                if isinstance(state, ShardedTensor):
130
                    self.assertTrue(isinstance(state2, ShardedTensor))
131
                    self.assertTrue(torch.allclose(state, state2))
132
                else:
133
                    self.assertEqual(state, state2)
134

135

136
instantiate_parametrized_tests(FsdpOptimStateCheckpoint)
137
if __name__ == "__main__":
138
    run_tests()
139

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.