1
from collections import OrderedDict
# This file contains helper functions that implement experimental functionality
# for named tensors in Python. All of these are experimental, unstable, and
# subject to change or deletion.
10
def check_serializing_named_tensor(tensor):
    """Reject serialization of a tensor that carries dimension names.

    Args:
        tensor: object exposing ``has_names()`` (presumably a ``torch.Tensor``).

    Raises:
        RuntimeError: if ``tensor.has_names()`` is true; named tensors cannot be
            serialized yet, so the caller must drop the names first.
    """
    if tensor.has_names():
        raise RuntimeError(
            "NYI: Named tensors don't support serialization. Please drop "
            "names via `tensor = tensor.rename(None)` before serialization."
        )
18
def build_dim_map(tensor):
    """Returns a map of { dim: dim_name } where dim is a name if the dim is named
    and the dim index otherwise.

    The result is an ``OrderedDict`` that iterates in the same order as
    ``tensor.names``. Unnamed dims map to ``None``, so
    ``tuple(result.values())`` reproduces ``tensor.names``.
    """
    return OrderedDict(
        [(idx if name is None else name, name) for idx, name in enumerate(tensor.names)]
    )
26
def unzip_namedshape(namedshape):
    """Split a named shape into parallel sequences of names and sizes.

    Args:
        namedshape: a mapping of name -> size (``OrderedDict`` or plain ``dict``,
            both iterate in insertion order), or any iterable of
            ``(name, size)`` pairs, including one-shot iterables such as
            generators.

    Returns:
        A single-use ``zip`` iterator over the transposed pairs. It is typically
        unpacked as ``names, sizes = unzip_namedshape(...)``.

    Raises:
        RuntimeError: if ``namedshape`` is not iterable, or is empty.
    """
    # Any dict (OrderedDict is a subclass) must be unpacked via its items;
    # iterating a dict directly would yield only the keys.
    if isinstance(namedshape, dict):
        namedshape = namedshape.items()
    if not hasattr(namedshape, "__iter__"):
        raise RuntimeError(
            f"Expected namedshape to be OrderedDict or iterable of tuples, got: {type(namedshape)}"
        )
    # Materialize once: generators and other one-shot iterables have no len(),
    # and they could only be consumed a single time anyway.
    namedshape = tuple(namedshape)
    if len(namedshape) == 0:
        raise RuntimeError("Expected namedshape to non-empty.")
    return zip(*namedshape)
38
def namer_api_name(inplace):
    """Return the user-facing rename API name used in error messages.

    ``"rename_"`` for the in-place variant, ``"rename"`` otherwise.
    """
    return "rename_" if inplace else "rename"


def is_ellipsis(item):
    """Return True if ``item`` spells an ellipsis: the ``...`` singleton or the
    string ``"..."``."""
    return item is Ellipsis or item == "..."
49
def single_ellipsis_index(names, fn_name):
    """Return the position of the lone ellipsis in ``names``, or ``None``.

    Args:
        names: sequence of dim names that may contain one ``...`` / ``"..."``.
        fn_name: API name used to prefix the error message.

    Returns:
        The index of the ellipsis, or ``None`` when ``names`` has no ellipsis.

    Raises:
        RuntimeError: if ``names`` contains more than one ellipsis.
    """
    ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)]
    if len(ellipsis_indices) >= 2:
        raise RuntimeError(
            f"{fn_name}: More than one Ellipsis ('...') found in names ("
            f"{names}). This function supports up to one Ellipsis."
        )
    if len(ellipsis_indices) == 1:
        return ellipsis_indices[0]
    return None
61
def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names):
    """Return the run of ``names`` that a single ellipsis stands for.

    The first ``numel_pre_glob`` entries and the last ``numel_post_glob``
    entries are excluded; everything in between is what the ellipsis globs.
    """
    stop = len(names) - numel_post_glob
    return names[numel_pre_glob:stop]
65
def replace_ellipsis_by_position(ellipsis_idx, names, tensor_names):
    """Replace ``names[ellipsis_idx]`` with the matching run of ``tensor_names``.

    Names before the ellipsis line up with the front of ``tensor_names`` and
    names after it line up with the back; the ellipsis absorbs what is left.

    NOTE(review): ``names`` and the slice of ``tensor_names`` are concatenated
    with ``+``, so both are assumed to be the same sequence type (e.g. both
    tuples). Mixing list and tuple would raise TypeError.
    """
    globbed_names = expand_single_ellipsis(
        ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names
    )
    return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1 :]
72
def resolve_ellipsis(names, tensor_names, fn_name):
    """
    Expands ... inside `names` to be equal to a list of names from `tensor_names`.

    Returns ``names`` unchanged when it contains no ellipsis. Locating (and
    validating) the ellipsis is delegated to ``single_ellipsis_index``, which
    receives ``fn_name`` for its error messages.
    """
    ellipsis_idx = single_ellipsis_index(names, fn_name)
    if ellipsis_idx is None:
        # Nothing to expand; passing None on would break the index arithmetic.
        return names
    return replace_ellipsis_by_position(ellipsis_idx, names, tensor_names)
82
def update_names_with_list(tensor, names, inplace):
    """Apply positional names to ``tensor`` (backs ``tensor.rename(*names)``).

    A single ``None`` clears all names. Otherwise an ellipsis in ``names`` is
    expanded against ``tensor.names`` before the update.

    Returns:
        Whatever ``tensor._update_names(new_names, inplace)`` returns.
    """
    # Special case for tensor.rename(None)
    if len(names) == 1 and names[0] is None:
        return tensor._update_names(None, inplace)

    return tensor._update_names(
        resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace
    )
92
def update_names_with_mapping(tensor, rename_map, inplace):
    """Rename individual dims of ``tensor`` (backs ``tensor.rename(**rename_map)``).

    Args:
        tensor: object exposing ``names`` and ``_update_names(names, inplace)``.
        rename_map: mapping of existing dim key -> new name. Keys are looked up
            in ``build_dim_map(tensor)``, i.e. the name for named dims and the
            positional index for unnamed ones.
        inplace: forwarded to ``_update_names``; also selects the API name
            shown in the error message.

    Returns:
        Whatever ``tensor._update_names`` returns for the new names tuple.

    Raises:
        RuntimeError: if a key in ``rename_map`` does not match an existing dim.
    """
    dim_map = build_dim_map(tensor)
    for old_dim, new_dim in rename_map.items():
        if old_dim not in dim_map:
            raise RuntimeError(
                f"{namer_api_name(inplace)}: Tried to rename dim '{old_dim}' to dim "
                f"{new_dim} in Tensor[{tensor.names}] but dim '{old_dim}' does not exist"
            )
        dim_map[old_dim] = new_dim
    # dim_map keeps tensor dim order, so its values form the full new names tuple.
    return tensor._update_names(tuple(dim_map.values()), inplace)
106
def update_names(tensor, names, rename_map, inplace):
    """There are two usages:

    tensor.rename(*names) returns a view on tensor with named dims `names`.
    `names` must be of length `tensor.dim()`; otherwise, if '...' is in `names`,
    then it is expanded greedily to be equal to the corresponding names from
    `tensor.names`.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename('...', 'height', 'width').names
    ('N', 'C', 'height', 'width')

    >>> # xdoctest: +SKIP
    >>> x.rename('batch', '...', 'width').names
    ('batch', 'C', 'H', 'width')

    ```

    tensor.rename(**rename_map) returns a view on tensor that has renamed dims
    as specified in the mapping `rename_map`.

    For example,
    ```
    >>> # xdoctest: +SKIP
    >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
    >>> x.rename(W='width', H='height').names
    ('N', 'C', 'height', 'width')

    ```

    Finally, tensor.rename has an in-place version called tensor.rename_.

    Passing both positional names and keyword renames raises RuntimeError.
    """
    api_name = namer_api_name(inplace)
    has_names = len(names) > 0
    has_rename_pairs = bool(rename_map)
    if has_names and has_rename_pairs:
        raise RuntimeError(
            f"{api_name}: This function takes either positional "
            f"args or keyword args, but not both. Use tensor.{api_name}(*names) "
            f"to name dims and tensor.{api_name}(**rename_map) to rename "
            "dims."
        )

    if has_rename_pairs:
        return update_names_with_mapping(tensor, rename_map, inplace)
    # Positional names. This also covers the special case tensor.rename(*[])
    # (no names and no rename pairs), which is valid for a 0 dim tensor.
    return update_names_with_list(tensor, names, inplace)