import math
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor

if dist.is_available() or TYPE_CHECKING:
    from torch.distributed import distributed_c10d
    from torch.distributed._shard.sharded_tensor import ShardedTensor
    from torch.distributed._tensor import DTensor, Replicate


def _all_gather_sharded_tensor(
    sharded_tensor: "ShardedTensor",
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    pg_device = (
        distributed_c10d._get_pg_default_device(pg) if device is None else device
    )
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if local_tensor.device.type != pg_device.type:
            local_tensor = local_tensor.to(pg_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=pg_device
        )

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=pg_device,
    )
    dist.all_gather_into_tensor(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor
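
# Illustration of the chunking math above (not part of the original file):
# for a ShardedTensor of shape (10, 4) with world_size=4, dim_0_size=10 and
# tensor_numel=40, chunk_size = ceil(10 / 4) * 40 // 10 = 12, i.e. three rows
# of four elements per rank. Each rank contributes exactly 12 elements
# (zero-padding its flattened shard when it is short), the gathered buffer
# holds 48 elements, and narrow() trims it back to the true 40 before
# reshaping to (10, 4).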


def _iterate_state_dict(
    iter_object: Any,
    sharded_tensor_func: Callable,
    dtensor_func: Callable,
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = tuple(),
) -> Dict[str, Any]:
    # TODO: should we use pytree?
    cpu_device = torch.device("cpu")
    if isinstance(iter_object, ShardedTensor):
        ret = sharded_tensor_func(iter_object, pg, device)
    elif isinstance(iter_object, DTensor):
        ret = dtensor_func(iter_object, pg, device)
    elif (
        isinstance(iter_object, (torch.Tensor, int, float, str)) or iter_object is None
    ):
        ret = iter_object
    elif isinstance(iter_object, dict):
        ret = {
            key: _iterate_state_dict(
                value,
                sharded_tensor_func,
                dtensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                ranks_only=ranks_only,
            )
            for key, value in iter_object.items()
        }
    elif isinstance(iter_object, (list, tuple)):
        ret = [
            _iterate_state_dict(
                v,
                sharded_tensor_func,
                dtensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                ranks_only=ranks_only,
            )
            for v in iter_object
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)
    else:
        raise ValueError(f"Unexpected value type {type(iter_object)}")

    if not ranks_only or dist.get_rank(pg) in ranks_only:
        if isinstance(ret, torch.Tensor) and cpu_offload:
            ret = ret.to(cpu_device)
    else:
        ret = {} if isinstance(ret, dict) else None

    return ret
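
# Sketch of the recursion above (illustrative, not part of the original file):
# a nested state_dict such as {"layer": {"weight": <ShardedTensor>}, "step": 3}
# is walked depth-first. A ShardedTensor leaf is replaced by the result of
# sharded_tensor_func, a DTensor leaf by dtensor_func, while plain tensors,
# scalars, and strings pass through unchanged (and tensors are moved to CPU
# when cpu_offload=True).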


def _gather_state_dict(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = tuple(),
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
    the state_dict.

    Args:
        state_dict (Dict[str, Any]): the target sharded state_dict.
        pg (Optional[dist.ProcessGroup]): the process group that is used to
            gather ShardedTensor. Note that gathering a DTensor will use
            the DeviceMesh, so this argument will be ignored when gathering a
            DTensor.
        device (Optional[torch.device]): the device that is used to
            perform allgather for ShardedTensor. Note that gathering a DTensor
            will use the DeviceMesh, so this argument will be ignored when
            gathering a DTensor.
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
            default value is False.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise, only ranks that are in
            ``ranks_only`` have the same state_dicts. Other ranks will get
            empty state_dicts.

    Returns:
        The gathered state dictionary.
    """

    def sharded_tensor_func(value, pg, device):
        # ShardedTensor does not seem to record the original device type.
        # So if the tensor is moved to CPU, we won't know the original type.
        # As a result, we have to rely on the user to tell us the correct one.
        cpu_device = torch.device("cpu")
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
        local_shard_device = (
            value.local_shards()[0].tensor.device
            if value.local_shards()
            else cpu_device
        )
        if output_tensor.device != local_shard_device:
            value = output_tensor.to(local_shard_device)
        else:
            value = output_tensor
        return value

    def dtensor_func(value, pg, device):
        # Compare device types: `value.device_mesh.device_type` is a string
        # such as "cuda", not a `torch.device`, so comparing it directly
        # against `value.device` would never match.
        if value.device.type != value.device_mesh.device_type:
            value = value.to(value.device_mesh.device_type)
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
        # 2D FSDP + TP all_gather:
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
        placements = [Replicate() for _ in value.placements]
        value = value.redistribute(
            device_mesh=value.device_mesh,
            placements=placements,
        )
        # Call `wait()` to force the tensor to be synchronous with respect
        # to the main stream.
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
        value = value.to_local()
        if isinstance(value, AsyncCollectiveTensor):
            value = value.wait()
        return value

    return _iterate_state_dict(
        state_dict,
        sharded_tensor_func,
        dtensor_func,
        pg=pg,
        device=device,
        cpu_offload=cpu_offload,
        ranks_only=ranks_only,
    )
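
# Hedged usage sketch (illustrative, not part of the original file): gather a
# sharded state_dict so that only rank 0 holds the full, CPU-resident copy,
# while every other rank receives an empty dict. `model` here stands in for a
# hypothetical FSDP- or TP-sharded module.
#
#   full_sd = _gather_state_dict(
#       model.state_dict(),
#       cpu_offload=True,
#       ranks_only=(0,),
#   )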


def _offload_state_dict_to_cpu(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    ranks_only: Tuple[int, ...] = tuple(),
) -> Dict[str, Any]:
    return _iterate_state_dict(
        state_dict,
        lambda value, pg, device: value,
        lambda value, pg, device: value,
        pg=pg,
        device=device,
        cpu_offload=True,
        ranks_only=ranks_only,
    )
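
# Hedged usage sketch (illustrative, not part of the original file): offload a
# state_dict to CPU without gathering shards. The identity lambdas hand
# ShardedTensor/DTensor values back unchanged, so this relies entirely on the
# cpu_offload=True path in _iterate_state_dict to move tensor values to CPU.
#
#   cpu_sd = _offload_state_dict_to_cpu(model.state_dict())  # `model` is hypothetical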