pytorch

Форк
0
452 строки · 17.9 Кб
1
# mypy: ignore-errors
2

3
# Copyright (c) Facebook, Inc. and its affiliates.
4
# All rights reserved.
5
#
6
# This source code is licensed under the BSD-style license found in the
7
# LICENSE file in the root directory of this source tree.
8

9
import torch
10
import contextlib
11
import functools
12
import threading
13
from torch import Tensor
14
from typing import Any, Callable, Optional, Tuple, Union, List
15
from torch.utils._pytree import (
16
    tree_flatten,
17
    tree_unflatten,
18
    tree_map_,
19
    _broadcast_to_and_flatten,
20
    TreeSpec,
21
)
22
from functools import partial
23
import os
24
import itertools
25

26
from torch._C._functorch import (
27
    _add_batch_dim,
28
    _remove_batch_dim,
29
    _vmap_decrement_nesting,
30
    _vmap_increment_nesting,
31
    is_batchedtensor,
32
)
33

34
# Accepted shape of the `in_dims` argument: one int applied to every input, or a
# (potentially nested) tuple mirroring the structure of the inputs.
in_dims_t = Union[int, Tuple]
# Annotated shape of the `out_dims` argument: one int, or a tuple of ints.
out_dims_t = Union[int, Tuple[int, ...]]
36

37

38
def doesnt_support_saved_tensors_hooks(f):
    """Decorate ``f`` so that it always runs with saved-tensor hooks disabled.

    While the wrapped function executes, it is inside
    ``torch.autograd.graph.disable_saved_tensors_hooks`` with a message telling
    the user that torch.func transforms don't support saved tensor hooks.
    """
    # Built once per decorated function and reused on every call.
    hooks_disabled_message = (
        "torch.func transforms don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        disable_ctx = torch.autograd.graph.disable_saved_tensors_hooks(hooks_disabled_message)
        with disable_ctx:
            return f(*args, **kwargs)

    return wrapper
49

50

51
# Checks that all args-to-be-batched have the same batch dim size
52
def _validate_and_get_batch_size(
53
        flat_in_dims: List[Optional[int]],
54
        flat_args: List) -> int:
55
    batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
56
                   if in_dim is not None]
57
    if len(batch_sizes) == 0:
58
        raise ValueError('vmap: Expected at least one Tensor to vmap over')
59
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
60
        raise ValueError(
61
            f'vmap: Expected all tensors to have the same size in the mapped '
62
            f'dimension, got sizes {batch_sizes} for the mapped dimension')
63
    return batch_sizes[0]
64

65

66
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
67
    if isinstance(batched_outputs, tuple):
68
        return len(batched_outputs)
69
    return 1
70

71
# If value is a tuple, check it has length `num_elements`.
72
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
73

74

75
def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
76
    if not isinstance(value, tuple):
77
        return (value,) * num_elements
78
    if len(value) != num_elements:
79
        raise ValueError(error_message_lambda())
80
    return value
81

82

83
def _process_batched_inputs(
    in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    """Validate ``in_dims`` against ``args`` and flatten both.

    Returns:
        ``(batch_size, flat_in_dims, flat_args, args_spec)``. ``flat_in_dims``
        has one entry per leaf of ``args``, with negative dims normalized to
        their non-negative equivalent. ``args_spec`` is the pytree spec needed
        to rebuild ``args``.

    Raises:
        ValueError: if ``in_dims`` is neither an int nor a tuple, ``args`` is
            empty, the in_dims structure doesn't match ``args``, an in_dim is
            neither an int nor None, an int in_dim is paired with a non-Tensor,
            or an in_dim is out of range for its Tensor. Errors raised by
            ``_validate_and_get_batch_size`` also propagate.
    """
    def _error(detail):
        # Common prefix for in_dims-related messages. It is formatted only when an
        # error is actually raised, so the success path never calls _get_name.
        return ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): ' + detail)

    if not isinstance(in_dims, (int, tuple)):
        raise _error(
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if not args:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise _error(
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for position, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if in_dim is None:
            # Not mapped over: any argument type is acceptable.
            continue
        if not isinstance(in_dim, int):
            raise _error(
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if not isinstance(arg, Tensor):
            raise _error(
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        ndim = arg.dim()
        if not -ndim <= in_dim < ndim:
            raise _error(
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {ndim} so expected in_dim to satisfy '
                f'-{ndim} <= in_dim < {ndim}.')
        if in_dim < 0:
            # Normalize so that downstream code only sees 0 <= in_dim < ndim.
            flat_in_dims[position] = in_dim % ndim

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    return batch_size, flat_in_dims, flat_args, args_spec
128

129
# Creates BatchedTensors for every Tensor in arg that should be batched.
130
# Returns the (potentially) batched arguments and the batch_size.
131

132

133
def _create_batched_inputs(
134
        flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
135
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
136
    batched_inputs = [arg if in_dim is None else
137
                      _add_batch_dim(arg, in_dim, vmap_level)
138
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
139
    return tree_unflatten(batched_inputs, args_spec)
140

141

142
def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):
143

144
    if out_dim is None:
145
        if isinstance(batched_output, torch.Tensor) and is_batchedtensor(batched_output):
146
            raise ValueError(
147
                f'vmap({name}, ...): `{name}` can not return a '
148
                f'BatchedTensor when out_dim is None'
149
            )
150
        return batched_output
151

152
    # out_dim is non None
153
    if not isinstance(batched_output, torch.Tensor):
154
        raise ValueError(f'vmap({name}, ...): `{name}` must only return '
155
                         f'Tensors, got type {type(batched_output)}. '
156
                         'Did you mean to set out_dim= to None for output?')
157

158
    return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)
159

160

161
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
        out_dims: out_dims_t,
        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    """Undo the batching associated with ``vmap_level`` on every output leaf.

    ``out_dims`` is matched against the pytree structure of ``batched_outputs``.
    Each leaf is then processed by ``_maybe_remove_batch_dim`` with its
    matching out_dim, and the original output structure is rebuilt.

    Raises:
        ValueError: if ``out_dims`` is not compatible with the output
            structure, or (from ``_maybe_remove_batch_dim``) if a leaf does not
            fit its out_dim.
    """
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    def incompatible_error():
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
            f'out_dims is not compatible with the structure of `outputs`. '
            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
            f'has structure {output_spec}.')

    if isinstance(batched_outputs, torch.Tensor):
        # A bare Tensor output has a single leaf, so an int, a 1-tuple or None is
        # accepted for it explicitly (see test_out_dims_edge_case). Presumably
        # _broadcast_to_and_flatten would reject a 1-tuple against a leaf spec.
        if isinstance(out_dims, int):
            flat_out_dims = [out_dims]
        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
            flat_out_dims = out_dims
        elif out_dims is None:
            flat_out_dims = [out_dims]
        else:
            incompatible_error()
    else:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()

    # The name is only used for error messages. _get_name may fall back to
    # repr(func), which can be expensive for e.g. modules, so compute it once
    # instead of once per output leaf.
    func_name = _get_name(func)
    flat_outputs = [
        _maybe_remove_batch_dim(func_name, batched_output, vmap_level, batch_size, out_dim)
        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(flat_outputs, output_spec)
196

197

198
def _check_int_or_none(x, func, out_dims):
199
    if isinstance(x, int):
200
        return
201
    if x is None:
202
        return
203
    raise ValueError(
204
        f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
205
        f'an int, None or a python collection of ints representing where in the outputs the '
206
        f'vmapped dimension should appear.')
207

208

209
def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
    """Check that ``out_dims`` is an int, or a pytree whose leaves are ints or None.

    Raises:
        ValueError: (from ``_check_int_or_none``) on the first invalid leaf.
    """
    if not isinstance(out_dims, int):
        # Visit every leaf of the (possibly nested) out_dims.
        tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)
213

214

215
def _get_name(func: Callable):
216
    if hasattr(func, '__name__'):
217
        return func.__name__
218

219
    # Not all callables have __name__, in fact, only static functions/methods do.
220
    # A callable created via functools.partial or an nn.Module, to name some
221
    # examples, don't have a __name__.
222
    return repr(func)
223

224

225
# Flipped to True by lazy_load_decompositions() once it has run (or decided to skip).
DECOMPOSITIONS_LOADED = False
# Serializes the one-time work in lazy_load_decompositions().
DECOMPOSITIONS_LOCK = threading.Lock()
# torch.library.Library holding the vmap decomposition impls; None until loaded.
VMAP_DECOMPOSITIONS_LIB = None
228

229
# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
    """Register Python decompositions for a few aten ops under FuncTorchBatched.

    The work runs at most once. An unlocked check of ``DECOMPOSITIONS_LOADED``
    is followed by a re-check while holding ``DECOMPOSITIONS_LOCK``, so
    concurrent first callers register only once.

    Registration is skipped, but still marked as done, when the ``PYTORCH_JIT``
    environment variable is set to something other than "1" or when
    ``__debug__`` is False.

    Raises:
        RuntimeError: if one of the listed ops has no entry in
            ``torch._decomp.decomposition_table``.
    """
    global DECOMPOSITIONS_LOADED, VMAP_DECOMPOSITIONS_LIB
    if DECOMPOSITIONS_LOADED:
        return

    with DECOMPOSITIONS_LOCK:
        if DECOMPOSITIONS_LOADED:
            return

        jit_enabled = os.environ.get("PYTORCH_JIT", "1") == "1"
        if not (jit_enabled and __debug__):
            DECOMPOSITIONS_LOADED = True
            return

        # use an alternate way to register an operator into the decomposition table
        # _register_jit_decomposition doesn't work for some operators, e.g. addr,
        #  because the Tensor types generated cannot be unioned by torchscript
        # decomp should be type OpOverload
        VMAP_DECOMPOSITIONS_LIB = torch.library.Library("aten", "IMPL", "FuncTorchBatched")

        from torch._decomp import decomposition_table

        # Registered in this exact order; each overload is looked up just before use.
        for op_name in (
            "mse_loss_backward",
            "smooth_l1_loss_backward",
            "huber_loss_backward",
            "nll_loss_forward",
            "nll_loss2d_forward",
            "nll_loss_backward",
            "nll_loss2d_backward",
            "addr",
        ):
            decomp = getattr(torch.ops.aten, op_name).default
            if decomp not in decomposition_table:
                raise RuntimeError(f"could not find decomposition for {decomp}")
            VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])

        DECOMPOSITIONS_LOADED = True
269

270
def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
    """Core vmap implementation: validate inputs, then run ``func`` batched.

    With ``chunk_size=None`` the whole batch goes through a single
    ``_flat_vmap`` call. Otherwise the mapped inputs are split into chunks by
    ``_get_chunked_inputs``; ``_chunked_vmap`` vmaps each chunk and
    concatenates the results. ``kwargs`` are forwarded to ``func`` unbatched.
    """
    lazy_load_decompositions()
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)

    if chunk_size is None:
        return _flat_vmap(
            func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
        )

    chunks_flat_args = _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
    return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
                         args_spec, out_dims, randomness, **kwargs)
284

285
def get_chunk_sizes(total_elems, chunk_size):
    """Split ``total_elems`` into consecutive chunk sizes of at most ``chunk_size``.

    Returns ``total_elems // chunk_size`` entries equal to ``chunk_size``,
    followed by a single remainder entry when ``total_elems`` is not a
    multiple of ``chunk_size``. ``total_elems == 0`` yields ``[]``.

    ``chunk_size`` is expected to be a positive int; 0 raises ZeroDivisionError.
    """
    n_full_chunks, remainder = divmod(total_elems, chunk_size)
    chunk_sizes = [chunk_size] * n_full_chunks
    if remainder:
        chunk_sizes.append(remainder)
    return chunk_sizes
293

294
def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
295
    split_idxs = (batch_size,)
296
    if chunk_size is not None:
297
        chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
298
        split_idxs = tuple(itertools.accumulate(chunk_sizes))
299

300
    flat_args_chunks = tuple(
301
        t.tensor_split(split_idxs, dim=in_dim) if in_dim is not None else [t, ] * len(split_idxs)
302
        for t, in_dim in zip(flat_args, flat_in_dims)
303
    )
304

305
    # transpose chunk dim and flatten structure
306
    # chunks_flat_args is a list of flatten args
307
    chunks_flat_args = zip(*flat_args_chunks)
308
    return chunks_flat_args
309

310

311
def _flatten_chunks_output(chunks_output_):
312
    # chunks_output is a list of chunked outputs
313
    # flatten chunked outputs:
314
    flat_chunks_output = []
315
    arg_spec = None
316
    for output in chunks_output_:
317
        flat_output, arg_specs = tree_flatten(output)
318
        flat_chunks_output.append(flat_output)
319
        if arg_spec is None:
320
            arg_spec = arg_specs
321

322
    # transpose chunk dim and flatten structure
323
    # flat_output_chunks is flat list of chunks
324
    flat_output_chunks = list(zip(*flat_chunks_output))
325
    return flat_output_chunks, arg_spec
326

327

328
def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
329
    # concat chunks on out_dim
330
    flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
331
    assert len(flat_out_dims) == len(flat_output_chunks)
332
    flat_output = []
333
    for idx, out_dim in enumerate(flat_out_dims):
334
        flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
335
        # release tensors
336
        flat_output_chunks[idx] = None
337

338
    return flat_output
339

340

341
# Applies vmap on chunked_input and returns concatenated output over the chunks.
def _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs):
    """Run ``_flat_vmap`` on each chunk and concatenate the per-chunk results.

    With ``randomness == "same"``, the RNG state captured before the loop is
    restored before each chunk is processed. Chunks whose batch size is 0
    are skipped.
    """
    # NOTE(review): only torch.get_rng_state/set_rng_state are used here;
    # confirm whether other (e.g. CUDA) generators also need to be reset.
    saved_rng_state = None
    if randomness == "same":
        saved_rng_state = torch.get_rng_state()

    per_chunk_outputs = []
    for chunk_args in chunks_flat_args:
        chunk_batch_size = _validate_and_get_batch_size(flat_in_dims, chunk_args)

        # tensor_split can yield chunks with a 0-sized batch; there is nothing to compute.
        # Eg.
        # >>> t = torch.zeros(6, 1)
        # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
        # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
        #  tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
        if chunk_batch_size == 0:
            continue

        if saved_rng_state is not None:
            torch.set_rng_state(saved_rng_state)
        per_chunk_outputs.append(
            _flat_vmap(
                func, chunk_batch_size, flat_in_dims, chunk_args, args_spec, out_dims, randomness, **kwargs
            )
        )

    flat_output_chunks, output_spec = _flatten_chunks_output(per_chunk_outputs)

    # The chunk tensors are now referenced from both lists. Drop this one so the
    # concatenation step's per-output release actually frees memory.
    del per_chunk_outputs

    flat_output = _concat_chunked_outputs(out_dims, output_spec, flat_output_chunks)
    return tree_unflatten(flat_output, output_spec)
381

382

383
# Vmap refactored helper functions:
384
def _check_randomness_arg(randomness):
385
    if randomness not in ['error', 'different', 'same']:
386
        raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")
387

388

389
@contextlib.contextmanager
def vmap_increment_nesting(batch_size, randomness):
    """Context manager that enters a new vmap level for the duration of the block.

    Yields the level returned by ``_vmap_increment_nesting(batch_size,
    randomness)``. ``_vmap_decrement_nesting()`` is always called on exit,
    including when the body raises.
    """
    # Increment outside the ``try``. If entering the level fails (presumably
    # e.g. for an invalid ``randomness`` value), no level was pushed, so
    # decrementing in ``finally`` would pop an unrelated enclosing level (or
    # fail on an empty stack) and mask the original error.
    vmap_level = _vmap_increment_nesting(batch_size, randomness)
    try:
        yield vmap_level
    finally:
        _vmap_decrement_nesting()
396

397

398
@doesnt_support_saved_tensors_hooks
def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):
    """Call ``func`` once inside a fresh vmap level, over already-flattened inputs.

    Steps:
    - wrap the inputs as batched at the new level;
    - call ``func`` on the rebuilt argument pytree (``kwargs`` are passed
      through unbatched);
    - unwrap the outputs according to ``out_dims``.

    The decorator disables saved-tensor hooks for the whole call.
    """
    with vmap_increment_nesting(batch_size, randomness) as level:
        inputs = _create_batched_inputs(flat_in_dims, flat_args, level, args_spec)
        outputs = func(*inputs, **kwargs)
        return _unwrap_batched(outputs, out_dims, level, batch_size, func)
405

406

407
# `restore_vmap` is a private helper function. It is vmap but has the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
#   out_dims is a pytree of same shape as outputs and contains Optional[int]
#   specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped).
#   restore_vmap allows for no inputs to have the vmap dimension
# - does no validation on outputs (vmap expects only Tensor outputs)
#   restore_vmap allows for return of arbitrary outputs (not just Tensors)
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are so that we can "pause" vmap in the middle
# of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap could technically be used in the implementation of vmap, but
# that refactor is challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
#   in python because it overlaps with unwrap_batched
@doesnt_support_saved_tensors_hooks
def restore_vmap(func, in_dims, batch_size, randomness):
    """Build a callable that re-enters vmap at a fresh level and runs ``func``.

    The returned callable does the following:
    - wraps its positional ``args`` as batched per ``in_dims`` at the new level;
    - calls ``func`` (``kwargs`` are passed through unbatched);
    - returns ``(outputs, out_dims)`` as produced by ``unwrap_batched``.
    """
    # NOTE(review): the decorator above wraps this factory, so saved-tensor hooks
    # are only disabled while ``inner`` is being created, not while it runs.
    # Decorating ``inner`` instead could change behavior for callers that invoke
    # it outside an enclosing vmap, so it is left as-is; confirm intent.
    def inner(*args, **kwargs):
        with vmap_increment_nesting(batch_size, randomness) as level:
            wrapped_args = wrap_batched(args, in_dims, level)
            raw_outputs = func(*wrapped_args, **kwargs)
            return unwrap_batched(raw_outputs, level)

    return inner
435

436

437
def wrap_batched(args, bdims, level):
    """Wrap the leaves of ``args`` as batched at ``level``, per ``bdims``.

    ``bdims`` is broadcast against the pytree structure of ``args``. Leaves
    whose bdim is None are left untouched. The broadcast is only asserted,
    not validated with a user-facing error.
    """
    leaves, spec = tree_flatten(args)
    leaf_bdims = _broadcast_to_and_flatten(bdims, spec)
    assert leaf_bdims is not None
    return _create_batched_inputs(leaf_bdims, leaves, level, spec)
443

444

445
def unwrap_batched(args, level):
    """Unwrap every Tensor leaf of ``args`` from ``level``.

    Returns:
        ``(outputs, bdims)``, both with the pytree structure of ``args``.
        Tensor leaves are replaced by the pair returned by
        ``torch._C._functorch._unwrap_batched(leaf, level)``. Other leaves are
        kept as-is with a bdim of None. If ``args`` has no leaves, returns
        ``(args, ())``.
    """
    leaves, spec = tree_flatten(args)
    if not leaves:
        return args, ()

    unwrapped = []
    dims = []
    for leaf in leaves:
        if isinstance(leaf, torch.Tensor):
            value, dim = torch._C._functorch._unwrap_batched(leaf, level)
        else:
            value, dim = leaf, None
        unwrapped.append(value)
        dims.append(dim)
    return tree_unflatten(unwrapped, spec), tree_unflatten(dims, spec)
453

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.