pytorch

Форк
0
236 строк · 10.4 Кб
1
import warnings
2
import torch
3
from torch.cuda import nccl
4
from torch._utils import _take_tensors, _flatten_dense_tensors, \
5
    _unflatten_dense_tensors, _reorder_tensors_as, _get_device_index, _handle_complex
6
from typing import List
7

8
def broadcast(tensor, devices=None, *, out=None):
    r"""Copy :attr:`tensor` to several GPU devices.

    Args:
        tensor (Tensor): the tensor to copy. It may reside on the CPU or a GPU.
        devices (Iterable[torch.device, str or int], optional): the GPU
          devices that should each receive a copy.
        out (Sequence[Tensor], optional, keyword-only): pre-allocated GPU
          tensors that receive the copies.

    .. note::
        Pass exactly one of :attr:`devices` or :attr:`out`.

    Returns:
        - When :attr:`devices` is given,
            a tuple of copies of :attr:`tensor`, placed on :attr:`devices`.
        - When :attr:`out` is given,
            a tuple of the :attr:`out` tensors, each now holding a copy of
            :attr:`tensor`.

    Raises:
        RuntimeError: if both or neither of :attr:`devices` and :attr:`out`
            are supplied.
    """
    tensor = _handle_complex(tensor)
    # Both-None and both-set are equally invalid: the two "is None" flags must differ.
    if (devices is None) == (out is None):
        raise RuntimeError(
            f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}")
    if out is not None:
        return torch._C._broadcast_out(tensor, out)
    device_indices = [_get_device_index(dev) for dev in devices]
    return torch._C._broadcast(tensor, device_indices)
38

39

40
def broadcast_coalesced(tensors, devices, buffer_size=10485760):
    """Copy a sequence of tensors to several GPUs.

    To cut down on synchronizations, small tensors are packed together into a
    buffer before being sent.

    Args:
        tensors (sequence): the tensors to copy. All of them must reside on
          the same device, either CPU or GPU.
        devices (Iterable[torch.device, str or int]): the GPU devices that
          should each receive a copy.
        buffer_size (int): upper bound on the size of the coalescing buffer.

    Returns:
        Copies of :attr:`tensors` placed on :attr:`devices`, exactly as
        produced by ``torch._C._broadcast_coalesced``.
    """
    # Resolve devices before touching the tensors (same order as before, so the
    # same error surfaces first when both arguments are invalid).
    device_ids = list(map(_get_device_index, devices))
    real_tensors = list(map(_handle_complex, tensors))
    return torch._C._broadcast_coalesced(real_tensors, device_ids, buffer_size)
58

59

60
def reduce_add(inputs, destination=None):
    """Sum tensors from multiple GPUs.

    All inputs should have matching shapes, dtype, and layout. The output tensor
    will be of the same shape, dtype, and layout.

    Args:
        inputs (Iterable[Tensor]): an iterable of tensors to add. Any iterable
            (including a generator) is accepted; it is materialized into a
            list before use.
        destination (int, optional): a device on which the output will be
            placed (default: current device).

    Returns:
        A tensor containing an elementwise sum of all inputs, placed on the
        :attr:`destination` device. When there is a single input, that input
        tensor itself is returned (no copy is made).

    Raises:
        AssertionError: if any input is a CPU tensor.
        ValueError: if an input's size differs from the first input's size.
        RuntimeError: if none of the inputs is on :attr:`destination`.
    """
    destination = _get_device_index(destination, optional=True)
    # The code below indexes into ``inputs`` and walks it several times, so a
    # one-shot iterable (e.g. a generator) must be materialized first.
    inputs = list(inputs)
    input_size = inputs[0].size()
    root_index = None  # index of input tensor that already is on the correct device
    for i, inp in enumerate(inputs):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so existing callers see the same exception type.
        if inp.device.type == "cpu":
            raise AssertionError("reduce_add expects all inputs to be on GPUs")
        if inp.get_device() == destination:
            root_index = i
        if inp.size() != input_size:
            got = 'x'.join(str(x) for x in inp.size())
            expected = 'x'.join(str(x) for x in input_size)
            raise ValueError(f"input {i} has invalid size: got {got}, but expected {expected}")
    if root_index is None:
        raise RuntimeError("reduce_add expects destination to be on the same GPU with one of the tensors")

    if len(inputs) == 1:
        return inputs[0]

    if nccl.is_available(inputs):
        result = torch.empty_like(inputs[root_index])
        nccl.reduce(inputs, output=result, root=root_index)
    else:
        destination_device = torch.device(inputs[root_index].device.type, destination)
        nonroot = [t for i, t in enumerate(inputs) if i != root_index]
        # The out-of-place ``+`` allocates a fresh result tensor, so the root
        # input is left unmodified without an explicit clone.
        result = inputs[root_index] + nonroot[0].to(device=destination_device, non_blocking=True)
        for other in nonroot[1:]:
            result.add_(other.to(device=destination_device, non_blocking=True))
    return result
103

104

105
def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
    """Sum tensors from multiple GPUs.

    Small tensors are first coalesced into a buffer to reduce the number
    of synchronizations.

    Args:
        inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
            contain tensors from a single device. The outer iterable may be a
            one-shot iterator (e.g. a generator); it is materialized first.
            Tensors are matched across devices by position (``zip`` stops at
            the shortest inner iterable).
        destination (int, optional): a device on which the output will be
            placed (default: current device).
        buffer_size (int): maximum size of the buffer used for coalescing

    Returns:
        A tuple of tensors containing an elementwise sum of each group of
        inputs, placed on the ``destination`` device.
    """
    # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
    #       return `inputs`.
    # ``inputs`` is traversed twice below (to size ``dense_tensors`` and in
    # ``zip(*inputs)``). Materialize it so a generator is not exhausted by the
    # first pass, which would otherwise make this silently return ``()``.
    inputs = list(inputs)
    dense_tensors: List[List] = [[] for _ in inputs]  # shape (num_gpus, num_tensors)
    output = []
    ref_order = []
    # process sparse ones first since they may have different sizes on different gpus
    for tensor_at_gpus in zip(*inputs):
        if all(t.is_sparse for t in tensor_at_gpus):
            result = reduce_add(tensor_at_gpus, destination)  # this will be sparse too
            output.append(result)
            ref_order.append(tensor_at_gpus[0])
        else:
            # Mixed sparse/dense groups are densified so they can be flattened.
            for coll, t in zip(dense_tensors, tensor_at_gpus):
                coll.append(t.to_dense() if t.is_sparse else t)
            ref_order.append(dense_tensors[0][-1])
    itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
    # now the dense ones, which have consistent sizes
    for chunks in zip(*itrs):
        flat_tensors = [_flatten_dense_tensors(chunk) for chunk in chunks]  # (num_gpus,)
        flat_result = reduce_add(flat_tensors, destination)
        for t in _unflatten_dense_tensors(flat_result, chunks[0]):
            # The unflattened tensors do not share storage, and we don't expose
            # base flat tensor anyways, so give them different version counters.
            # See NOTE [ Version Counter in comm.*_coalesced ]
            output.append(t.data)
    # Restore the caller's original tensor order (sparse results were appended first).
    return tuple(_reorder_tensors_as(output, ref_order))
148

149

150
def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
    """Scatters tensor across multiple GPUs.

    Args:
        tensor (Tensor): tensor to scatter. Can be on CPU or GPU.
        devices (Iterable[torch.device, str or int], optional): an iterable of
          GPU devices, among which to scatter.
        chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
          each device. It should match :attr:`devices` in length and sums to
          ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided
          into equal chunks.
        dim (int, optional): A dimension along which to chunk :attr:`tensor`.
          Default: ``0``.
        streams (Iterable[torch.cuda.Stream], optional): an iterable of Streams, among
          which to execute the scatter. If not specified, the default stream will
          be utilized.
        out (Sequence[Tensor], optional, keyword-only): the GPU tensors to
          store output results. Sizes of these tensors must match that of
          :attr:`tensor`, except for :attr:`dim`, where the total size must
          sum to ``tensor.size(dim)``.

    .. note::
        Exactly one of :attr:`devices` and :attr:`out` must be specified. When
        :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and
        will be inferred from sizes of :attr:`out`.

    Returns:
        - If :attr:`devices` is specified,
            a tuple containing chunks of :attr:`tensor`, placed on
            :attr:`devices`.
        - If :attr:`out` is specified,
            a tuple containing :attr:`out` tensors, each containing a chunk of
            :attr:`tensor`.

    Raises:
        RuntimeError: if neither :attr:`devices` nor :attr:`out` is given, or
            if :attr:`out` is given together with :attr:`devices` or
            :attr:`chunk_sizes`.
    """
    tensor = _handle_complex(tensor)
    if out is None:
        # Same contract and message as ``broadcast``: without this check a
        # missing ``devices`` fails with an opaque
        # "'NoneType' object is not iterable" TypeError.
        if devices is None:
            raise RuntimeError(
                f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}")
        devices = [_get_device_index(d) for d in devices]
        return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
    else:
        if devices is not None:
            raise RuntimeError(
                f"'devices' must not be specified when 'out' is specified, but got devices={devices}")
        if chunk_sizes is not None:
            raise RuntimeError(
                f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}")
        return tuple(torch._C._scatter_out(tensor, out, dim, streams))
196

197

198
def gather(tensors, dim=0, destination=None, *, out=None):
    r"""Concatenate tensors that live on several GPU devices.

    Args:
        tensors (Iterable[Tensor]): the tensors to concatenate. Their sizes
          must agree in every dimension except :attr:`dim`.
        dim (int, optional): the dimension along which to concatenate.
          Default: ``0``.
        destination (torch.device, str, or int, optional): the device that
          receives the result, CPU or CUDA. Default: the current CUDA device.
        out (Tensor, optional, keyword-only): pre-allocated tensor (CPU or
          CUDA) that receives the result. Its sizes must match those of
          :attr:`tensors` except along :attr:`dim`, where the size must equal
          ``sum(tensor.size(dim) for tensor in tensors)``.

    .. note::
        :attr:`destination` and :attr:`out` cannot be combined.

    Returns:
        - Without :attr:`out`,
            a tensor on :attr:`destination` holding :attr:`tensors`
            concatenated along :attr:`dim`.
        - With :attr:`out`,
            the :attr:`out` tensor, now holding :attr:`tensors` concatenated
            along :attr:`dim`.

    Raises:
        RuntimeError: if both :attr:`out` and :attr:`destination` are given.
    """
    tensors = list(map(_handle_complex, tensors))
    if out is not None:
        if destination is not None:
            raise RuntimeError(
                f"'destination' must not be specified when 'out' is specified, but got destination={destination}")
        return torch._C._gather_out(tensors, out, dim)

    # -1 is the legacy spelling of "CPU"; it still works but is deprecated.
    if destination == -1:
        warnings.warn(
            'Using -1 to represent CPU tensor is deprecated. Please use a '
            'device object or string instead, e.g., "cpu".')
    target_index = _get_device_index(destination, allow_cpu=True, optional=True)
    return torch._C._gather(tensors, dim, target_index)
237

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.