pytorch / _functions.py (126 lines · 4.7 KB)
import warnings

import torch
from . import comm
from torch.autograd import Function
from torch._utils import _get_device_index
from typing import List, Optional


class Broadcast(Function):

    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Broadcast function not implemented for CPU tensors'
        )
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return tuple()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        # Copies of inputs that do not require grad are marked
        # non-differentiable so autograd skips them in backward.
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        # Flatten: all inputs on the first device, then all on the second, ...
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)
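
# Usage sketch (editorial addition, not part of the original file; assumes a
# machine with at least two CUDA devices). Broadcast replicates tensors onto
# several GPUs and returns them flat, one copy of every input per device:
#
#   >>> x = torch.randn(2, 3, device="cuda:0")
#   >>> y = torch.randn(4, device="cuda:0")
#   >>> x0, y0, x1, y1 = Broadcast.apply([0, 1], x, y)
#   >>> x1.device
#   device(type='cuda', index=1)
#
# In backward, the per-device gradients are reduce-added back onto
# ctx.input_device via ReduceAddCoalesced.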


class ReduceAddCoalesced(Function):

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        # grads arrive flattened, num_inputs tensors per source device.
        ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)]

        grads_ = [grads[i:i + num_inputs]
                  for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads_, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
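
# Usage sketch (editorial addition; assumes two CUDA devices).
# ReduceAddCoalesced is the inverse of Broadcast: it takes num_inputs
# gradients per device, grouped device-by-device, and sums them onto the
# destination device:
#
#   >>> g0 = torch.ones(2, device="cuda:0")
#   >>> g1 = torch.ones(2, device="cuda:1")
#   >>> (total,) = ReduceAddCoalesced.apply(0, 1, g0, g1)
#   >>> total  # g0 + g1, placed on cuda:0
#   tensor([2., 2.], device='cuda:0')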


class Gather(Function):

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Gather function not implemented for CPU tensors'
        )
        if target_device == 'cpu':
            ctx.target_device = 'cpu'
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn('Was asked to gather along dimension 0, but all '
                          'input tensors were scalars; will instead unsqueeze '
                          'and return a vector.')
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads
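
# Usage sketch (editorial addition; assumes two CUDA devices). Gather
# concatenates per-device tensors along dim on one target device (a device
# index, or 'cpu'):
#
#   >>> a = torch.randn(2, 3, device="cuda:0")
#   >>> b = torch.randn(4, 3, device="cuda:1")
#   >>> out = Gather.apply("cpu", 0, a, b)
#   >>> out.shape
#   torch.Size([6, 3])
#
# ctx.input_sizes records each chunk's extent along dim so the backward pass
# can Scatter grad_output back to the source devices in matching pieces.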


class Scatter(Function):

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
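
# Usage sketch (editorial addition; assumes two CUDA devices). Scatter splits
# one tensor into chunks along dim, one chunk per device; chunk_sizes=None
# splits as evenly as possible:
#
#   >>> full = torch.randn(6, 3)  # CPU input, so copies run on side streams
#   >>> a, b = Scatter.apply([0, 1], None, 0, full)
#   >>> a.shape, a.device
#   (torch.Size([3, 3]), device(type='cuda', index=0))
#
# record_stream() tells the caching allocator that each output is consumed on
# the current stream, so its memory is not reused while the copy issued on
# the side stream may still be in flight.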


# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None

def _get_stream(device: torch.device):
    """Get a background stream for copying between CPU and target device."""
    global _streams
    if device.type == "cpu":
        return None
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    if _streams is None:
        _streams = [None] * device_mod.device_count()
    if _streams[device.index] is None:
        _streams[device.index] = device_mod.Stream(device.index)
    return _streams[device.index]
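
# Note (editorial addition): _streams is lazily sized to device_count() and
# caches one side stream per device, so repeated Scatter calls reuse the same
# copy stream instead of allocating a new one each time. For example:
#
#   >>> s = _get_stream(torch.device("cuda", 0))  # a torch.cuda.Stream
#   >>> _get_stream(torch.device("cuda", 0)) is s
#   True
#   >>> _get_stream(torch.device("cpu")) is None
#   True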