import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
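

# A quick illustration of the placement strings built by the helper below
# (hypothetical values, not from the original file): rank 11 with 8 CUDA
# devices per node maps to "rank:11/cuda:3" (11 % 8 == 3), while CPU
# placements omit the device index: "rank:11/cpu".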
def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []
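
    # Note: torch.chunk may return fewer than world_size chunks (e.g., when
    # dim 0 is small), in which case the trailing ranks hold no local shard.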
    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(r, device_type, num_devices_per_node)
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )


def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank gets its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]
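
    # Wrap the full tensor as a replicated DTensor (run_check=False skips the
    # cross-rank verification), then redistribute to shard dim 0 along the
    # innermost mesh dimension; replicate -> shard slices locally, without
    # communication.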
    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert parent_mesh is None

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()