# mypy: ignore-errors

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import contextlib
import functools
import threading
from torch import Tensor
from typing import Any, Callable, Optional, Tuple, Union, List
from torch.utils._pytree import (
    tree_flatten,
    tree_unflatten,
    tree_map_,
    _broadcast_to_and_flatten,
    TreeSpec,
)
from functools import partial
import os
import itertools

from torch._C._functorch import (
    _add_batch_dim,
    _remove_batch_dim,
    _vmap_decrement_nesting,
    _vmap_increment_nesting,
    is_batchedtensor,
)

in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]

def doesnt_support_saved_tensors_hooks(f):
    message = (
        "torch.func transforms don't yet support saved tensor hooks. "
        "Please open an issue with your use case."
    )

    @functools.wraps(f)
    def fn(*args, **kwargs):
        with torch.autograd.graph.disable_saved_tensors_hooks(message):
            return f(*args, **kwargs)
    return fn


# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
        flat_in_dims: List[Optional[int]],
        flat_args: List) -> int:
    batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
                   if in_dim is not None]
    if len(batch_sizes) == 0:
        raise ValueError('vmap: Expected at least one Tensor to vmap over')
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f'vmap: Expected all tensors to have the same size in the mapped '
            f'dimension, got sizes {batch_sizes} for the mapped dimension')
    return batch_sizes[0]
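
# Illustrative example of the check above (hypothetical values, kept as a
# comment so nothing runs at import time): both mapped dims have size 3, so
# the returned batch size is 3.
# >>> x, y = torch.randn(3, 5), torch.randn(7, 3)
# >>> _validate_and_get_batch_size([0, 1], [x, y])
# 3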

def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times


def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value
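
# Illustrative example of _as_tuple (hypothetical calls, kept as a comment):
# >>> _as_tuple(0, 3, lambda: "bad value")
# (0, 0, 0)
# >>> _as_tuple((1, 2), 2, lambda: "bad value")
# (1, 2)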

def _process_batched_inputs(
    in_dims: in_dims_t, args: Tuple, func: Callable
) -> Tuple[int, List[Any], List[Any], TreeSpec]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'expected `in_dims` to be int or a (potentially nested) tuple '
            f'matching the structure of inputs, got: {type(in_dims)}.')
    if len(args) == 0:
        raise ValueError(
            f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
            f'inputs, or you are trying to vmap over a function with no inputs. '
            f'The latter is unsupported.')

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
            f'in_dims is not compatible with the structure of `inputs`. '
            f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
            f'has structure {args_spec}.')

    for i, (arg, in_dim) in enumerate(zip(flat_args, flat_in_dims)):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but in_dim must be either '
                f'an integer dimension or None.')
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for an input but the input is of type '
                f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
                f'please use None as the respective in_dim')
        if in_dim is not None and (in_dim < -arg.dim() or in_dim >= arg.dim()):
            raise ValueError(
                f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
                f'Got in_dim={in_dim} for some input, but that input is a Tensor '
                f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
                f'-{arg.dim()} <= in_dim < {arg.dim()}.')
        if in_dim is not None and in_dim < 0:
            flat_in_dims[i] = in_dim % arg.dim()

    return _validate_and_get_batch_size(flat_in_dims, flat_args), flat_in_dims, flat_args, args_spec
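
# Illustrative example of the normalization above (hypothetical values): for
# args of shapes ((2, 5), (5, 2)) and in_dims=(0, -1), the negative dim is
# wrapped to a positive one, so flat_in_dims becomes [0, 1] and the returned
# batch_size is 2.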

# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.


def _create_batched_inputs(
        flat_in_dims: List[Any], flat_args: List[Any], vmap_level: int, args_spec) -> Tuple:
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [arg if in_dim is None else
                      _add_batch_dim(arg, in_dim, vmap_level)
                      for in_dim, arg in zip(flat_in_dims, flat_args)]
    return tree_unflatten(batched_inputs, args_spec)


def _maybe_remove_batch_dim(name, batched_output, vmap_level, batch_size, out_dim):

    if out_dim is None:
        if isinstance(batched_output, torch.Tensor) and is_batchedtensor(batched_output):
            raise ValueError(
                f'vmap({name}, ...): `{name}` can not return a '
                f'BatchedTensor when out_dim is None'
            )
        return batched_output

    # out_dim is not None
    if not isinstance(batched_output, torch.Tensor):
        raise ValueError(f'vmap({name}, ...): `{name}` must only return '
                         f'Tensors, got type {type(batched_output)}. '
                         'Did you mean to set out_dim= to None for output?')

    return _remove_batch_dim(batched_output, vmap_level, batch_size, out_dim)


# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
        batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
        out_dims: out_dims_t,
        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
    flat_batched_outputs, output_spec = tree_flatten(batched_outputs)

    def incompatible_error():
        raise ValueError(
            f'vmap({_get_name(func)}, ..., out_dims={out_dims})(<inputs>): '
            f'out_dims is not compatible with the structure of `outputs`. '
            f'out_dims has structure {tree_flatten(out_dims)[1]} but outputs '
            f'has structure {output_spec}.')

    if isinstance(batched_outputs, torch.Tensor):
        # Some weird edge case requires us to spell out the following
        # see test_out_dims_edge_case
        if isinstance(out_dims, int):
            flat_out_dims = [out_dims]
        elif isinstance(out_dims, tuple) and len(out_dims) == 1:
            flat_out_dims = out_dims
        elif out_dims is None:
            flat_out_dims = [out_dims]
        else:
            incompatible_error()
    else:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, output_spec)
        if flat_out_dims is None:
            incompatible_error()

    flat_outputs = [
        _maybe_remove_batch_dim(_get_name(func), batched_output, vmap_level, batch_size, out_dim)
        for batched_output, out_dim in zip(flat_batched_outputs, flat_out_dims)
    ]
    return tree_unflatten(flat_outputs, output_spec)


def _check_int_or_none(x, func, out_dims):
    if isinstance(x, int):
        return
    if x is None:
        return
    raise ValueError(
        f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
        f'an int, None or a python collection of ints representing where in the outputs the '
        f'vmapped dimension should appear.')


def _check_out_dims_is_int_or_int_pytree(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    tree_map_(partial(_check_int_or_none, func=func, out_dims=out_dims), out_dims)


def _get_name(func: Callable):
    if hasattr(func, '__name__'):
        return func.__name__

    # Not all callables have __name__, in fact, only static functions/methods do.
    # A callable created via functools.partial or an nn.Module, to name some
    # examples, don't have a __name__.
    return repr(func)

DECOMPOSITIONS_LOADED = False
DECOMPOSITIONS_LOCK = threading.Lock()
VMAP_DECOMPOSITIONS_LIB = None

# torch.package, Python 3.11, and torch.jit-less environments are unhappy with
# decompositions. Only load them when needed if possible.
def lazy_load_decompositions():
    global DECOMPOSITIONS_LOADED
    if DECOMPOSITIONS_LOADED:
        return

    with DECOMPOSITIONS_LOCK:
        if DECOMPOSITIONS_LOADED:
            return

        if not (os.environ.get("PYTORCH_JIT", "1") == "1" and __debug__):
            DECOMPOSITIONS_LOADED = True
            return

        # use an alternate way to register an operator into the decomposition table
        # _register_jit_decomposition doesn't work for some operators, e.g. addr,
        # because the Tensor types generated cannot be unioned by torchscript
        # decomp should be type OpOverload
        global VMAP_DECOMPOSITIONS_LIB
        VMAP_DECOMPOSITIONS_LIB = torch.library.Library("aten", "IMPL", "FuncTorchBatched")

        from torch._decomp import decomposition_table

        def _register_python_decomposition_vmap(decomp):
            if decomp in decomposition_table:
                VMAP_DECOMPOSITIONS_LIB.impl(decomp, decomposition_table[decomp])
            else:
                raise RuntimeError(f"could not find decomposition for {decomp}")

        _register_python_decomposition_vmap(torch.ops.aten.mse_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.smooth_l1_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.huber_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_forward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.nll_loss2d_backward.default)
        _register_python_decomposition_vmap(torch.ops.aten.addr.default)

        DECOMPOSITIONS_LOADED = True

def vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs):
    lazy_load_decompositions()
    _check_out_dims_is_int_or_int_pytree(out_dims, func)
    batch_size, flat_in_dims, flat_args, args_spec = _process_batched_inputs(in_dims, args, func)

    if chunk_size is not None:
        chunks_flat_args = _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size)
        return _chunked_vmap(func, flat_in_dims, chunks_flat_args,
                             args_spec, out_dims, randomness, **kwargs)

    # If chunk_size is not specified.
    return _flat_vmap(
        func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
    )
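
# Illustrative sketch of how the public API reaches vmap_impl (torch.vmap
# forwards func, in_dims, out_dims, randomness and chunk_size here; kept as a
# comment so nothing runs at import time):
# >>> x = torch.randn(6, 3)
# >>> torch.vmap(torch.sum, in_dims=0, chunk_size=4)(x).shape
# torch.Size([6])
# With chunk_size=4, the 6-element batch is processed in chunks of 4 and 2.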

def get_chunk_sizes(total_elems, chunk_size):
    n_chunks = total_elems // chunk_size
    chunk_sizes = [chunk_size] * n_chunks
    # remainder chunk
    remainder = total_elems % chunk_size
    if remainder != 0:
        chunk_sizes.append(remainder)
    return chunk_sizes
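
# Illustrative examples of get_chunk_sizes, derived from the logic above:
# >>> get_chunk_sizes(7, 3)
# [3, 3, 1]
# >>> get_chunk_sizes(6, 3)
# [3, 3]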

def _get_chunked_inputs(flat_args, flat_in_dims, batch_size, chunk_size):
    split_idxs = (batch_size,)
    if chunk_size is not None:
        chunk_sizes = get_chunk_sizes(batch_size, chunk_size)
        split_idxs = tuple(itertools.accumulate(chunk_sizes))

    flat_args_chunks = tuple(
        t.tensor_split(split_idxs, dim=in_dim) if in_dim is not None else [t, ] * len(split_idxs)
        for t, in_dim in zip(flat_args, flat_in_dims)
    )

    # transpose chunk dim and flatten structure
    # chunks_flat_args is a list of flattened args
    chunks_flat_args = zip(*flat_args_chunks)
    return chunks_flat_args
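
# Illustrative example of the chunking above (hypothetical values): with
# batch_size=7 and chunk_size=3, chunk_sizes is [3, 3, 1] and split_idxs is
# (3, 6, 7); each batched arg is tensor_split at those indices along its
# in_dim, while unbatched args (in_dim=None) are simply repeated once per
# split index.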

def _flatten_chunks_output(chunks_output_):
    # chunks_output is a list of chunked outputs
    # flatten chunked outputs:
    flat_chunks_output = []
    arg_spec = None
    for output in chunks_output_:
        flat_output, arg_specs = tree_flatten(output)
        flat_chunks_output.append(flat_output)
        if arg_spec is None:
            arg_spec = arg_specs

    # transpose chunk dim and flatten structure
    # flat_output_chunks is a flat list of chunks
    flat_output_chunks = list(zip(*flat_chunks_output))
    return flat_output_chunks, arg_spec


def _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks):
    # concat chunks on out_dim
    flat_out_dims = _broadcast_to_and_flatten(out_dims, arg_spec)
    assert len(flat_out_dims) == len(flat_output_chunks)
    flat_output = []
    for idx, out_dim in enumerate(flat_out_dims):
        flat_output.append(torch.cat(flat_output_chunks[idx], dim=out_dim))
        # release tensors
        flat_output_chunks[idx] = None

    return flat_output


# Applies vmap on chunked_input and returns concatenated output over the chunks.
def _chunked_vmap(func, flat_in_dims, chunks_flat_args, args_spec, out_dims, randomness, **kwargs):

    chunks_output = []
    rs = torch.get_rng_state() if randomness == "same" else None
    for flat_args in chunks_flat_args:
        batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)

        # Because of the way we split the input in `_get_chunked_inputs`,
        # we may get a tensor with a `0` batch-size. We skip any computation
        # in that case.
        # Eg.
        # >>> chunk_size = 1
        # >>> batch_size = 6
        # >>> t = torch.zeros(batch_size, 1)
        # >>> t.tensor_split([1, 2, 3, 4, 5, 6])
        # (tensor([[0.]]), tensor([[0.]]), tensor([[0.]]), tensor([[0.]]),
        #  tensor([[0.]]), tensor([[0.]]), tensor([], size=(0, 1)))
        if batch_size == 0:
            continue

        if rs is not None:
            torch.set_rng_state(rs)
        chunks_output.append(
            _flat_vmap(
                func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
            )
        )

    flat_output_chunks, arg_spec = _flatten_chunks_output(chunks_output)

    # chunked output tensors are held by both `flat_output_chunks` and `chunks_output`.
    # eagerly remove the reference from `chunks_output`.
    del chunks_output

    # concat chunks on out_dim
    flat_output = _concat_chunked_outputs(out_dims, arg_spec, flat_output_chunks)

    # finally, unflatten the output
    return tree_unflatten(flat_output, arg_spec)


# Vmap refactored helper functions:
def _check_randomness_arg(randomness):
    if randomness not in ['error', 'different', 'same']:
        raise RuntimeError(f"Only allowed values for randomness are 'error', 'different', or 'same'. Got {randomness}")


@contextlib.contextmanager
def vmap_increment_nesting(batch_size, randomness):
    try:
        vmap_level = _vmap_increment_nesting(batch_size, randomness)
        yield vmap_level
    finally:
        _vmap_decrement_nesting()


@doesnt_support_saved_tensors_hooks
def _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs):

    with vmap_increment_nesting(batch_size, randomness) as vmap_level:
        batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
        batched_outputs = func(*batched_inputs, **kwargs)
        return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)

# `restore_vmap` is a private helper function. It is vmap with the following
# differences:
# - instead of returning outputs, it returns an (outputs, out_dims) tuple.
#   out_dims is a pytree of the same shape as outputs and contains Optional[int]
#   specifying where the vmapped dimension, if it exists, is in the corresponding output.
# - does no validation on in_dims or inputs (vmap expects at least one Tensor to be vmapped);
#   restore_vmap allows none of the inputs to have the vmap dimension
# - does no validation on outputs (vmap expects only Tensor outputs);
#   restore_vmap allows arbitrary outputs to be returned (not just Tensors)
#
# The TL;DR is that restore_vmap is more general than vmap and has a slightly
# different API. The relaxations are there so that we can "pause" vmap in the
# middle of its execution and then "restore" it later (this is what we do in
# the generate_vmap_rule=True implementation of autograd.Function).
#
# restore_vmap could technically be used in the implementation of vmap, but
# that refactor is a bit challenging because:
# - vmap couples the tensor-wrapping code with error checking
# - vmap's tensor unwrapping code is in C++; we would need to rewrite part of it
#   in Python because it overlaps with unwrap_batched
@doesnt_support_saved_tensors_hooks
def restore_vmap(func, in_dims, batch_size, randomness):
    def inner(*args, **kwargs):
        with vmap_increment_nesting(batch_size, randomness) as vmap_level:
            batched_inputs = wrap_batched(args, in_dims, vmap_level)
            batched_outputs = func(*batched_inputs, **kwargs)
            return unwrap_batched(batched_outputs, vmap_level)
    return inner
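
# Illustrative sketch of restore_vmap (hypothetical values, kept as a comment
# so nothing runs at import time): for x of shape (3, 5),
# >>> outputs, out_dims = restore_vmap(torch.sin, (0,), 3, 'error')(x)
# outputs should have shape (3, 5) and out_dims should be 0, i.e. the vmapped
# dimension of the output sits at dim 0.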

def wrap_batched(args, bdims, level):
    flat_args, spec = tree_flatten(args)
    flat_bdims = _broadcast_to_and_flatten(bdims, spec)
    assert flat_bdims is not None
    result = _create_batched_inputs(flat_bdims, flat_args, level, spec)
    return result


def unwrap_batched(args, level):
    flat_args, spec = tree_flatten(args)
    if len(flat_args) == 0:
        return args, ()
    result = [torch._C._functorch._unwrap_batched(arg, level) if isinstance(arg, torch.Tensor)
              else (arg, None) for arg in flat_args]
    output, bdims = zip(*result)
    return tree_unflatten(output, spec), tree_unflatten(bdims, spec)