pytorch

Форк
0
/
batch_tensor.py 
26 строк · 668.0 Байт
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
# All rights reserved.
3
#
4
# This source code is licensed under the BSD-style license found in the
5
# LICENSE file in the root directory of this source tree.
6
from contextlib import contextmanager
7

8
from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
9

10

11
_enabled = False
12

13

14
@contextmanager
15
def _enable_layers(dims):
16
    global _enabled
17
    assert not _enabled
18
    input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
19
    n = len(input)
20
    try:
21
        _vmap_add_layers(input)
22
        _enabled = True
23
        yield
24
    finally:
25
        _enabled = False
26
        _vmap_remove_layers(n)
27

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.