pytorch / _shard_utils.py · 127 lines · 4.2 KB

import copy
import itertools
import math
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard as DShard


def _get_remote_device_str(rank, device_type, num_devices_per_node):
    if device_type.lower() == "cpu":
        return f"rank:{rank}/{device_type}"
    else:
        return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}"

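# Illustrative examples (values chosen here, not part of the module), assuming
# 8 devices per node:
#   _get_remote_device_str(5, "cuda", 8)  -> "rank:5/cuda:5"
#   _get_remote_device_str(11, "cuda", 8) -> "rank:11/cuda:3"
#   _get_remote_device_str(5, "cpu", 8)   -> "rank:5/cpu"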

def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
    device: Optional[torch.device] = None,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank will get its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    # Ranks beyond the number of chunks (possible when dim 0 is smaller than
    # world_size) own no data and contribute no local shard.
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    device_type = (
        distributed_c10d._get_pg_default_device(pg).type
        if device is None
        else device.type
    )
    placements = [
        _get_remote_device_str(r, device_type, num_devices_per_node)
        for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )

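# Sketch of the chunking math above with illustrative values (not from the
# module): for tensor.size() == (10, 4) and world_size == 4, tensor.chunk(4, dim=0)
# yields row counts [3, 3, 3, 1], so chunk_sizes == [[3, 4], [3, 4], [3, 4], [1, 4]],
# dim0_offsets == [0, 3, 6, 9], and rank r's local shard starts at row
# math.ceil(10 / 4) * r == 3 * r.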

def _create_chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension. The local rank will get its
    corresponding chunk as the local tensor to create a DTensor.
    """
    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.clone().detach()

    # FSDP placements: [Shard(0)]
    # HSDP placements: [Replicate(), Shard(0)]
    replicate_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements = [Replicate() for _ in range(device_mesh.ndim)]
    shard_placements[-1] = DShard(0)  # type: ignore[call-overload]

    return DTensor.from_local(
        tensor, device_mesh, replicate_placements, run_check=False
    ).redistribute(
        placements=shard_placements,
    )

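# Placement sketch (illustrative): on a 1-D mesh (plain FSDP) the local tensor is
# first wrapped as [Replicate()] and then redistributed to [Shard(0)]; on a 2-D
# mesh (HSDP) [Replicate(), Replicate()] becomes [Replicate(), Shard(0)], i.e.
# the tensor is sharded only along the innermost mesh dimension.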

def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """
    All-gather a DTensor along its sharded dimension and return the local tensor.
    """
    assert parent_mesh is None

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP placements: [Shard(0)] -> [Replicate()]
    # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
    placements[-1] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()

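# Illustrative result (values chosen here): for a DTensor sharded as [Shard(0)]
# across 4 ranks over a (10, 4) global tensor, redistributing to [Replicate()]
# all-gathers the row chunks, so to_local() returns the full (10, 4) tensor on
# every rank.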