# pytorch / _state_dict_utils.py
import math
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor

if dist.is_available() or TYPE_CHECKING:
    from torch.distributed import distributed_c10d
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed._tensor import DTensor, Replicate


def _all_gather_sharded_tensor(
    sharded_tensor: "ShardedTensor",
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if local_tensor.device.type != pg_device.type:
            local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor


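# Illustrative worked example of the chunking math in _all_gather_sharded_tensor
# above (hypothetical numbers, not part of the upstream module). Assume a
# ShardedTensor of shape (5, 3) sharded on dim 0 across world_size = 2:
#   dim_0_size = 5, tensor_numel = 15
#   chunk_size = ceil(5 / 2) * 15 // 5 = 45 // 5 = 9
#   rank 0 holds rows 0..2 -> 9 flattened elements, no padding needed
#   rank 1 holds rows 3..4 -> 6 flattened elements, padded with 3 zeros to 9
# all_gather_into_tensor then fills an 18-element buffer, and
# narrow(0, 0, 15).reshape((5, 3)) drops the padding and restores the shape.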
def _iterate_state_dict(
    iter_object: Any,
    sharded_tensor_func: Callable,
    dtensor_func: Callable,
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = tuple(),
) -> Dict[str, Any]:
    # TODO: should we use pytree?
    cpu_device = torch.device("cpu")
    if isinstance(iter_object, ShardedTensor):
        ret = sharded_tensor_func(iter_object, pg, device)
    elif isinstance(iter_object, DTensor):
        ret = dtensor_func(iter_object, pg, device)
    elif (
        isinstance(iter_object, (torch.Tensor, int, float, str)) or iter_object is None
    ):
        ret = iter_object
    elif isinstance(iter_object, dict):
        ret = {
            key: _iterate_state_dict(
                value,
                sharded_tensor_func,
                dtensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                ranks_only=ranks_only,
            )
            for key, value in iter_object.items()
        }
    elif isinstance(iter_object, (list, tuple)):
        ret = [
            _iterate_state_dict(
                v,
                sharded_tensor_func,
                dtensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                ranks_only=ranks_only,
            )
            for v in iter_object
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)
    else:
        raise ValueError(f"Unexpected value type {type(iter_object)}")

    if not ranks_only or dist.get_rank(pg) in ranks_only:
        if isinstance(ret, torch.Tensor) and cpu_offload:
            ret = ret.to(cpu_device)
    else:
        ret = {} if isinstance(ret, dict) else None

    return ret


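# Illustrative sketch (hypothetical values, not upstream code) of the nesting
# that _iterate_state_dict above recurses over. ShardedTensor/DTensor leaves go
# through the supplied callbacks; plain tensors, scalars, strings, and None pass
# through unchanged; dicts, lists, and tuples are handled recursively.
#
#   example_state_dict = {
#       "layer.weight": some_dtensor,      # dispatched to dtensor_func
#       "layer.bias": torch.randn(8),      # plain tensor, passed through
#       "step": 42,                        # int, passed through
#       "param_groups": [{"lr": 1e-3}],    # list/dict, recursed into
#   }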
def _gather_state_dict(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = tuple(),
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
    the state_dict.

    Args:
        state_dict (Dict[str, Any]): the target sharded state_dict.
        pg (Optional[dist.ProcessGroup]): the process group that is used to
            gather ShardedTensor. Note that gathering a DTensor will use
            the DeviceMesh. So this argument will be ignored when gathering a
            DTensor.
        device (Optional[torch.device]): the device that is used to
            perform allgather for ShardedTensor. Note that gathering a DTensor
            will use the DeviceMesh. So this argument will be ignored when
            gathering a DTensor.
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
            default value is False.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in
            ``ranks_only`` have the same state_dicts. Other ranks will get
            empty state_dicts.

    Returns:
        The gathered state dictionary.
    """

    def sharded_tensor_func(value, pg, device):
        # ShardedTensor does not seem to record the original device type.
        # So if the tensor is moved to CPU, we won't know the original type.
        # As a result, we have to rely on the user to tell us the correct one.
        cpu_device = torch.device("cpu")
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
        local_shard_device = (
            value.local_shards()[0].tensor.device
            if value.local_shards()
            else cpu_device
        )
        if output_tensor.device != local_shard_device:
            value = output_tensor.to(local_shard_device)
        else:
            value = output_tensor
        return value

    def dtensor_func(value, pg, device):
        if value.device != value.device_mesh.device_type:
            value = value.to(value.device_mesh.device_type)
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
        # 2D FSDP + TP all_gather:
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
        placements = [Replicate() for _ in value.placements]
        value = value.redistribute(
            device_mesh=value.device_mesh,
            placements=placements,
        )
        # Call `wait()` to force the tensor to be synchronous with respect
        # to the main stream.
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
        value = value.to_local()
        if isinstance(value, AsyncCollectiveTensor):
            value = value.wait()
        return value

    return _iterate_state_dict(
        state_dict,
        sharded_tensor_func,
        dtensor_func,
        pg=pg,
        device=device,
        cpu_offload=cpu_offload,
        ranks_only=ranks_only,
    )


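# Illustrative usage sketch for _gather_state_dict above (hypothetical names;
# assumes an initialized process group and a model whose state_dict contains
# ShardedTensor/DTensor values):
#
#   sharded_sd = model.state_dict()
#   full_sd = _gather_state_dict(sharded_sd, cpu_offload=True, ranks_only=(0,))
#   # rank 0 now holds the fully gathered, CPU-resident state_dict; every other
#   # rank gets back an empty dict, although all ranks still participate in the
#   # underlying collectives.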
def _offload_state_dict_to_cpu(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    ranks_only: Tuple[int, ...] = tuple(),
) -> Dict[str, Any]:
    return _iterate_state_dict(
        state_dict,
        lambda value, pg, device: value,
        lambda value, pg, device: value,
        pg=pg,
        device=device,
        cpu_offload=True,
        ranks_only=ranks_only,
    )
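# Illustrative note on _offload_state_dict_to_cpu above: it reuses
# _iterate_state_dict with identity callbacks, so nothing is gathered; it is
# only a CPU-offload pass over the state_dict. Hypothetical usage:
#
#   cpu_sd = _offload_state_dict_to_cpu(model.state_dict())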
