pytorch

Форк
0
121 строка · 3.3 Кб
1
# Copyright (c) Facebook, Inc. and its affiliates.
2
# All rights reserved.
3
#
4
# This source code is licensed under the BSD-style license found in the
5
# LICENSE file in the root directory of this source tree.
6
import dis
7
import inspect
8
from dataclasses import dataclass
9
from typing import Union
10

11
from . import DimList
12

13

14
_vmap_levels = []
15

16

17
@dataclass
18
class LevelInfo:
19
    level: int
20
    alive: bool = True
21

22

23
class Dim:
24
    def __init__(self, name: str, size: Union[None, int] = None):
25
        self.name = name
26
        self._size = None
27
        self._vmap_level = None
28
        if size is not None:
29
            self.size = size
30

31
    def __del__(self):
32
        if self._vmap_level is not None:
33
            _vmap_active_levels[self._vmap_stack].alive = False  # noqa: F821
34
            while (
35
                not _vmap_levels[-1].alive
36
                and current_level() == _vmap_levels[-1].level  # noqa: F821
37
            ):
38
                _vmap_decrement_nesting()  # noqa: F821
39
                _vmap_levels.pop()
40

41
    @property
42
    def size(self):
43
        assert self.is_bound
44
        return self._size
45

46
    @size.setter
47
    def size(self, size: int):
48
        from . import DimensionBindError
49

50
        if self._size is None:
51
            self._size = size
52
            self._vmap_level = _vmap_increment_nesting(size, "same")  # noqa: F821
53
            self._vmap_stack = len(_vmap_levels)
54
            _vmap_levels.append(LevelInfo(self._vmap_level))
55

56
        elif self._size != size:
57
            raise DimensionBindError(
58
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
59
            )
60

61
    @property
62
    def is_bound(self):
63
        return self._size is not None
64

65
    def __repr__(self):
66
        return self.name
67

68

69
def extract_name(inst):
70
    assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
71
    return inst.argval
72

73

74
_cache = {}
75

76

77
def dims(lists=0):
78
    frame = inspect.currentframe()
79
    assert frame is not None
80
    calling_frame = frame.f_back
81
    assert calling_frame is not None
82
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
83
    key = (code, lasti)
84
    if key not in _cache:
85
        first = lasti // 2 + 1
86
        instructions = list(dis.get_instructions(calling_frame.f_code))
87
        unpack = instructions[first]
88

89
        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
90
            # just a single dim, not a list
91
            name = unpack.argval
92
            ctor = Dim if lists == 0 else DimList
93
            _cache[key] = lambda: ctor(name=name)
94
        else:
95
            assert unpack.opname == "UNPACK_SEQUENCE"
96
            ndims = unpack.argval
97
            names = tuple(
98
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
99
            )
100
            first_list = len(names) - lists
101
            _cache[key] = lambda: tuple(
102
                Dim(n) if i < first_list else DimList(name=n)
103
                for i, n in enumerate(names)
104
            )
105
    return _cache[key]()
106

107

108
def _dim_set(positional, arg):
109
    def convert(a):
110
        if isinstance(a, Dim):
111
            return a
112
        else:
113
            assert isinstance(a, int)
114
            return positional[a]
115

116
    if arg is None:
117
        return positional
118
    elif not isinstance(arg, (Dim, int)):
119
        return tuple(convert(a) for a in arg)
120
    else:
121
        return (convert(arg),)
122

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.