import threading
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import torch
from torch._utils import ExceptionWrapper
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast

from ..modules import Module
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "get_a_var",
    "parallel_apply",
]
def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    """Return the first :class:`torch.Tensor` found in ``obj`` by depth-first search.

    ``obj`` itself is returned if it is a tensor. Lists and tuples are searched
    element by element, recursing into nested containers. Dicts are searched
    over their ``(key, value)`` pairs, so a tensor used as a key is found before
    the value stored under it.

    Args:
        obj: a tensor, or an arbitrarily nested list/tuple/dict that may hold tensors.

    Returns:
        The first tensor encountered, or ``None`` if there is none. Any object
        that is not a tensor, list, tuple or dict is treated as a leaf holding
        no tensor.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        # Each item is a (key, value) tuple, so it is handled by the tuple branch above.
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None
def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (sequence of dict, optional): keyword arguments for each
            module; an empty dict is used for every module when not given
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.

    Returns:
        list: ``modules[i](*inputs[i], **kwargs_tup[i])`` for every ``i``, in
        index order. An empty list when :attr:`modules` is empty.

    Raises:
        AssertionError: if the argument lengths disagree.
        An exception raised inside a replica is captured in an
        ``ExceptionWrapper`` and re-raised in the calling thread via
        ``ExceptionWrapper.reraise()``.
    """
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup), f'The number of modules {len(modules)} is not equal to the length of kwargs_tup {len(kwargs_tup)}'
    else:
        # One shared empty dict is safe: each call unpacks it with ``**kwargs``,
        # so the dict itself is never mutated.
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices), f'The number of modules {len(modules)} is not equal to the number of devices {len(devices)}'
    else:
        devices = [None] * len(modules)
    # Normalize every entry to a device index; None entries are resolved by
    # _get_device_index with optional=True.
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    # replica index -> module output, or ExceptionWrapper on failure; workers write under `lock`.
    results: Dict[int, Any] = {}
    # Grad mode is thread-local, so the caller's state is captured here and
    # re-applied inside each worker thread (autocast state is forwarded the same way).
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Run replica ``i`` and store its output (or wrapped exception) in ``results``."""
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    # NOTE(review): this wrapper is built outside an `except`
                    # block, so it presumably captures no active exception;
                    # verify ExceptionWrapper.reraise() copes with that.
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            # Capture the exception so it can be re-raised in the calling thread.
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    elif modules:
        # A single replica runs in the calling thread; no thread is spawned.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs