pytorch

Форк
0
28 строк · 617.0 Байт
1
import contextlib
2
from typing import Callable, List, TYPE_CHECKING
3

4
if TYPE_CHECKING:
5
    import torch
6

7
# Executed in the order they're registered
8
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
9

10

11
@contextlib.contextmanager
12
def intermediate_hook(fn):
13
    INTERMEDIATE_HOOKS.append(fn)
14
    try:
15
        yield
16
    finally:
17
        INTERMEDIATE_HOOKS.pop()
18

19

20
def run_intermediate_hooks(name, val):
21
    global INTERMEDIATE_HOOKS
22
    hooks = INTERMEDIATE_HOOKS
23
    INTERMEDIATE_HOOKS = []
24
    try:
25
        for hook in hooks:
26
            hook(name, val)
27
    finally:
28
        INTERMEDIATE_HOOKS = hooks
29

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.