pytorch

Форк
0
/
arg_scope.py 
32 строки · 1.1 Кб
1

2

3

4
import contextlib
5
import copy
6
import threading
7

8
_threadlocal_scope = threading.local()
9

10

11
@contextlib.contextmanager
12
def arg_scope(single_helper_or_list, **kwargs):
13
    global _threadlocal_scope
14
    if not isinstance(single_helper_or_list, list):
15
        assert callable(single_helper_or_list), \
16
            "arg_scope is only supporting single or a list of helper functions."
17
        single_helper_or_list = [single_helper_or_list]
18
    old_scope = copy.deepcopy(get_current_scope())
19
    for helper in single_helper_or_list:
20
        assert callable(helper), \
21
            "arg_scope is only supporting a list of callable helper functions."
22
        helper_key = helper.__name__
23
        if helper_key not in old_scope:
24
            _threadlocal_scope.current_scope[helper_key] = {}
25
        _threadlocal_scope.current_scope[helper_key].update(kwargs)
26

27
    yield
28
    _threadlocal_scope.current_scope = old_scope
29

30

31
def get_current_scope():
32
    global _threadlocal_scope
33
    if not hasattr(_threadlocal_scope, "current_scope"):
34
        _threadlocal_scope.current_scope = {}
35
    return _threadlocal_scope.current_scope
36

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.