pytorch

Форк
0
/
_namedtensor_internals.py 
157 строк · 5.1 Кб
1
from collections import OrderedDict
2

3
"""
This file contains helper functions that implement experimental functionality
for named tensors in Python. All of these are experimental, unstable, and
subject to change or deletion.
"""
8

9

10
def check_serializing_named_tensor(tensor):
    """Refuse to serialize a tensor that still carries dimension names.

    Args:
        tensor: object exposing ``has_names()``.

    Raises:
        RuntimeError: when ``tensor.has_names()`` is truthy; the message tells
            the user to drop names with ``tensor.rename(None)`` first.
    """
    if not tensor.has_names():
        return
    raise RuntimeError(
        "NYI: Named tensors don't support serialization. Please drop "
        "names via `tensor = tensor.rename(None)` before serialization."
    )
16

17

18
def build_dim_map(tensor):
    """Map every dim of *tensor*, in dim order, to its current name.

    The key for a dim is its name when it is named, and its integer position
    when it is unnamed. The value is the name itself, which is ``None`` for
    unnamed dims.

    Args:
        tensor: object exposing a ``names`` sequence (entries are names or ``None``).

    Returns:
        An ``OrderedDict`` of ``{dim_key: dim_name}``.
    """
    dim_map = OrderedDict()
    for position, dim_name in enumerate(tensor.names):
        key = position if dim_name is None else dim_name
        dim_map[key] = dim_name
    return dim_map
24

25

26
def unzip_namedshape(namedshape):
    """Split a named shape into parallel sequences of names and sizes.

    Args:
        namedshape: either a mapping of ``name -> size`` (an ``OrderedDict``
            or any ``dict``, which preserves insertion order) or an iterable
            of ``(name, size)`` pairs. One-shot iterators such as generators
            are accepted.

    Returns:
        A ``zip`` iterator over the transposed pairs. For ``(name, size)``
        pairs it yields the tuple of names and then the tuple of sizes, so it
        can be unpacked as ``names, sizes = unzip_namedshape(...)``.

    Raises:
        RuntimeError: if ``namedshape`` is not iterable or contains no pairs.
    """
    # Any dict (OrderedDict included) maps name -> size. Iterating it directly
    # would yield only the keys, and zip(*keys) would transpose their characters.
    if isinstance(namedshape, dict):
        namedshape = namedshape.items()
    if not hasattr(namedshape, "__iter__"):
        raise RuntimeError(
            f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}"
        )
    # Materialize once so one-shot iterators (e.g. generators) can be checked
    # for emptiness. Calling len() on them would raise TypeError.
    pairs = tuple(namedshape)
    if not pairs:
        raise RuntimeError("Expected namedshape to be non-empty.")
    return zip(*pairs)
36

37

38
def namer_api_name(inplace):
    """Return the user-facing method name used in error messages.

    Returns ``"rename_"`` for the in-place variant and ``"rename"`` otherwise.
    """
    return "rename_" if inplace else "rename"
43

44

45
def is_ellipsis(item):
    """Return True if *item* denotes an ellipsis.

    Both the ``Ellipsis`` singleton (``...``) and the string ``"..."`` count.
    """
    if item == Ellipsis:
        return True
    return item == "..."
47

48

49
def single_ellipsis_index(names, fn_name):
    """Find the position of the (at most one) ellipsis in *names*.

    An entry counts as an ellipsis when ``is_ellipsis`` accepts it.

    Args:
        names: sequence of dim names, possibly containing one ellipsis.
        fn_name: API name used as the prefix of the error message.

    Returns:
        The index of the ellipsis, or ``None`` if *names* contains none.

    Raises:
        RuntimeError: if *names* contains two or more ellipses.
    """
    found_at = None
    for position, name in enumerate(names):
        if not is_ellipsis(name):
            continue
        if found_at is not None:
            # A second ellipsis would make the expansion ambiguous.
            raise RuntimeError(
                f"{fn_name}: More than one Ellipsis ('...') found in names ("
                f"{names}). This function supports up to one Ellipsis."
            )
        found_at = position
    return found_at
59

60

61
def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
    """Return the run of *names* that an ellipsis stands for.

    The first ``numel_pre_glob`` entries and the last ``numel_post_glob``
    entries of *names* are skipped, and the middle slice is returned.

    Args:
        numel_pre_glob: number of explicit names before the ellipsis.
        numel_post_glob: number of explicit names after the ellipsis.
        names: the full sequence of names to draw from.
    """
    stop = len(names) - numel_post_glob
    return names[numel_pre_glob:stop]
63

64

65
def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
    """Substitute the ellipsis at *ellipsis_idx* in *names* with tensor names.

    The ellipsis is replaced by the entries of *tensor_names* that are not
    covered by the explicit names before it (``ellipsis_idx`` of them) and
    after it (``len(names) - ellipsis_idx - 1`` of them).

    NOTE(review): the pieces are joined with ``+``, so this assumes *names*
    and *tensor_names* are the same sequence type (e.g. both tuples).
    """
    num_after = len(names) - ellipsis_idx - 1
    expanded = expand_single_ellipsis(ellipsis_idx, num_after, tensor_names)
    before = names[:ellipsis_idx]
    after = names[ellipsis_idx + 1 :]
    return before + expanded + after
70

71

72
def resolve_ellipsis(names, tensor_names, fn_name):
    """Expand a single ``...`` in *names* using the matching *tensor_names*.

    Args:
        names: requested dim names, possibly containing one ellipsis.
        tensor_names: the tensor's current names, used to fill the ellipsis.
        fn_name: API name used in the error message for multiple ellipses.

    Returns:
        *names* itself (the same object) when it has no ellipsis. Otherwise a
        new sequence with the ellipsis expanded in place.

    Raises:
        RuntimeError: propagated from ``single_ellipsis_index`` when *names*
            contains more than one ellipsis.
    """
    ellipsis_idx = single_ellipsis_index(names, fn_name)
    if ellipsis_idx is not None:
        return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
    return names
80

81

82
def update_names_with_list(tensor, names, inplace):
    """Apply positional dim *names* to *tensor*, expanding any ellipsis.

    ``tensor.rename(None)``, i.e. ``names`` being exactly one ``None``, is
    special-cased to pass ``None`` to ``_update_names`` so that all names are
    dropped.

    Args:
        tensor: object exposing ``names`` and ``_update_names(names, inplace)``.
        names: the requested names, possibly containing one ellipsis.
        inplace: forwarded to ``_update_names``. It also selects the API name
            (``rename`` / ``rename_``) shown in ellipsis error messages.

    Returns:
        Whatever ``tensor._update_names`` returns.
    """
    drops_all_names = len(names) == 1 and names[0] is None
    if drops_all_names:
        return tensor._update_names(None, inplace)

    api_name = namer_api_name(inplace)
    resolved = resolve_ellipsis(names, tensor.names, api_name)
    return tensor._update_names(resolved, inplace)
90

91

92
def update_names_with_mapping(tensor, rename_map, inplace):
    """Rename selected dims of *tensor* according to *rename_map*.

    Dims not mentioned in *rename_map* keep their current names.

    Args:
        tensor: object exposing ``names`` and ``_update_names(names, inplace)``.
        rename_map: mapping of existing dim key -> new name. The keys are
            looked up among the keys produced by ``build_dim_map(tensor)``.
        inplace: forwarded to ``_update_names``. It also selects the API name
            (``rename`` / ``rename_``) shown in error messages.

    Returns:
        Whatever ``tensor._update_names`` returns for the new names tuple.

    Raises:
        RuntimeError: if a key of *rename_map* is not a dim of *tensor*.
    """
    dim_map = build_dim_map(tensor)
    for old_dim, new_dim in rename_map.items():
        if old_dim not in dim_map:
            raise RuntimeError(
                f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim "
                f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist"
            )
        dim_map[old_dim] = new_dim
    return tensor._update_names(tuple(dim_map.values()), inplace)
104

105

106
def update_names(tensor, names, rename_map, inplace):
    """Dispatch ``tensor.rename`` / ``tensor.rename_`` to positional or keyword handling.

    There are two mutually exclusive usages.

    ``tensor.rename(*names)`` returns a view on ``tensor`` with named dims
    ``names``. ``names`` must have length ``tensor.dim()``. If ``'...'``
    appears in ``names``, it is instead expanded greedily into the
    corresponding names from ``tensor.names``.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename('...', 'height', 'width').names
    ('N', 'C', 'height', 'width')

    >>> # xdoctest: +SKIP
    >>> x.rename('batch', '...', 'width').names
    ('batch', 'C', 'H', 'width')

    ```

    ``tensor.rename(**rename_map)`` returns a view on ``tensor`` whose dims
    are renamed as specified by the mapping ``rename_map`` (old name -> new name).

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename(W='width', H='height').names
    ('N', 'C', 'height', 'width')

    ```

    Finally, tensor.rename has an in-place version called tensor.rename_.

    Raises:
        RuntimeError: if both positional ``names`` and ``rename_map`` entries
            are supplied.
    """
    positional_given = len(names) > 0
    keywords_given = bool(rename_map)
    if positional_given and keywords_given:
        api = namer_api_name(inplace)
        raise RuntimeError(
            f"{api}: This function takes either positional "
            f"args or keyword args, but not both. Use tensor.{api}(*names) "
            f"to name dims and tensor.{api}(**rename_map) to rename "
            "dims."
        )

    # Past the check above, keyword pairs imply there are no positional names.
    if keywords_given:
        return update_names_with_mapping(tensor, rename_map, inplace)
    # Positional path. This also handles tensor.rename(*[]) with no arguments
    # at all, which is valid for a 0-dim tensor.
    return update_names_with_list(tensor, names, inplace)
158

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.