pytorch — parallel_apply.py (110 lines · 4.2 KB)
1
import threading
2
import torch
3
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
4
from ..modules import Module
5
from torch.cuda._utils import _get_device_index
6
from torch.cuda.amp import autocast
7
from torch._utils import ExceptionWrapper
8

9
# Public API of this module.
__all__ = ['get_a_var', 'parallel_apply']
10

11
def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    """Return the first tensor found in ``obj``, searching nested
    lists/tuples/dicts depth-first, or ``None`` if there is no tensor."""
    if isinstance(obj, torch.Tensor):
        return obj

    # Pick the children to recurse into; anything else holds no tensors.
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
24

25
def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (sequence of dict, optional): per-module keyword arguments;
            defaults to an empty dict for every module
        devices (list of int or torch.device): CUDA devices; when omitted,
            each module's device is inferred from the first tensor found in
            its input

    Returns:
        A list with one output per module, in the same order as :attr:`modules`.

    Raises:
        Whatever exception a replica raised: exceptions from worker threads
        are captured as ``ExceptionWrapper`` objects and re-raised in the
        calling thread, preserving the original traceback.

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    # Fill in defaults so every module has a kwargs dict and a device slot
    # (None meaning "infer from the input tensors").
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    # `results` is shared across worker threads; `lock` guards its mutation.
    lock = threading.Lock()
    results = {}
    # Autograd and autocast state are thread-local in PyTorch, so snapshot
    # them here and restore them inside each worker thread.
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        # Propagate the caller's grad mode into this thread.
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            # No explicit device: take it from the first tensor in the input.
            t = get_a_var(input)
            if t is None:
                # Record the failure instead of raising — exceptions must
                # cross the thread boundary via ExceptionWrapper.
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            # Capture any replica failure with its traceback; the main
            # thread re-raises it after all workers finish.
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

    if len(modules) > 1:
        # One thread per replica; each gets its own module/input/device/stream.
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Single replica: skip thread creation overhead and run inline.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    # Collect outputs in module order, re-raising the first captured
    # exception (with its original traceback) in the calling thread.
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.