pytorch
1# Copyright (c) Facebook, Inc. and its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6from contextlib import contextmanager7
8from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers9
10
# Module-level flag tracking whether an `_enable_layers` context is currently
# active: `_enable_layers` asserts it is False on entry, sets it to True once
# `_vmap_add_layers` has returned, and resets it to False on exit.
_enabled = False
13
@contextmanager
def _enable_layers(dims):
    """Run the ``with`` body between ``_vmap_add_layers``/``_vmap_remove_layers``.

    Each entry of ``dims`` that is not an ``int`` contributes one
    ``(d._level, d.size)`` pair. The pairs are sorted and passed to
    ``_vmap_add_layers`` on entry; on exit ``_vmap_remove_layers`` is called
    with the number of pairs, whether or not the body raised.

    The module-level ``_enabled`` flag is True while the body runs and is
    reset to False on exit. The context is not reentrant: entering it while
    ``_enabled`` is already True fails the assertion below.
    """
    global _enabled
    assert not _enabled

    # Plain ints carry no ``_level``/``size`` and are skipped.
    layers = [(dim._level, dim.size) for dim in dims if not isinstance(dim, int)]
    layers.sort()
    layer_count = len(layers)

    try:
        _vmap_add_layers(layers)
        _enabled = True
        yield
    finally:
        # NOTE: this cleanup also runs when `_vmap_add_layers` itself raised.
        _enabled = False
        _vmap_remove_layers(layer_count)