pytorch

Форк
0
/
_cxx_pytree.py 
1007 строк · 34.3 Кб
1
"""
2
Contains utility functions for working with nested python data structures.
3

4
A *pytree* is Python nested data structure. It is a tree in the sense that
5
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
6
Python values. Furthermore, a pytree should not contain reference cycles.
7

8
pytrees are useful for working with nested collections of Tensors. For example,
9
one can use `tree_map` to map a function over all Tensors inside some nested
10
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
11
inside some nested collection. pytrees are helpful for implementing nested
12
collection support for PyTorch APIs.
13
"""
14

15
import functools
16
import sys
17
import types
18
from typing import (
19
    Any,
20
    Callable,
21
    Iterable,
22
    List,
23
    Optional,
24
    overload,
25
    Tuple,
26
    Type,
27
    TypeVar,
28
    Union,
29
)
30
from typing_extensions import deprecated
31

32
import optree
33
from optree import PyTreeSpec  # direct import for type annotations
34

35
import torch.utils._pytree as _pytree
36
from torch.utils._pytree import KeyEntry
37

38

39
# Public names of ``torch.utils._cxx_pytree``. Grouped several per line for
# readability; the order of entries is significant and must be preserved.
__all__ = [
    "PyTree", "Context", "FlattenFunc", "UnflattenFunc",
    "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn",
    "TreeSpec", "LeafSpec",
    "keystr", "key_get",
    "register_pytree_node",
    "tree_flatten", "tree_flatten_with_path", "tree_unflatten",
    "tree_iter", "tree_leaves", "tree_leaves_with_path", "tree_structure",
    "tree_map", "tree_map_with_path", "tree_map_",
    "tree_map_only", "tree_map_only_",
    "tree_all", "tree_any", "tree_all_only", "tree_any_only",
    "treespec_dumps", "treespec_loads", "treespec_pprint",
]
72

73

74
T = TypeVar("T")
75
S = TypeVar("S")
76
U = TypeVar("U")
77
R = TypeVar("R")
78

79

80
Context = Any
81
PyTree = Any
82
TreeSpec = PyTreeSpec
83
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
84
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
85
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
86
DumpableContext = Any  # Any json dumpable text
87
ToDumpableContextFn = Callable[[Context], DumpableContext]
88
FromDumpableContextFn = Callable[[DumpableContext], Context]
89
KeyPath = Tuple[KeyEntry, ...]
90
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
91

92

93
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
    """Adapt a ``(children, context)`` unflatten function to optree's calling convention.

    The returned callable forwards its positional arguments to ``func`` in reverse
    order, so optree's ``(context, children)`` call becomes ``func(children, context)``.
    Keyword arguments are forwarded unchanged.
    """

    @functools.wraps(func)
    def swapped(*args: Any, **kwargs: Any) -> Any:
        reordered = args[::-1]
        return func(*reordered, **kwargs)

    return swapped
99

100

101
def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node.

    The type is registered with the C++ pytree registry used by this module and
    also with the Python pytree registry in :mod:`torch.utils._pytree`, so both
    implementations treat ``cls`` as a node.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
            recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
            passed to the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments: the unflattened children, and
            the auxiliary data that was returned by ``flatten_fn`` and stored in the treespec.
            The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec.
        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable representation. This is
            used for json serialization, which is being used in :mod:`torch.export` right now.
        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
            how to convert the custom json dumpable representation of the context back to the
            original context. This is used for json deserialization, which is being used in
            :mod:`torch.export` right now.
        flatten_with_keys_fn (callable, optional): Key-path flattening is not supported by the
            C++ pytree; this must be left as ``None``.

    Raises:
        NotImplementedError: If ``flatten_with_keys_fn`` is not ``None``.

    Example::

        >>> # xdoctest: +SKIP
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None),
        ...     lambda children, _: set(children),
        ... )
    """
    if flatten_with_keys_fn is not None:
        raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")

    # C++ (optree) registry. The serialization options are accepted here but only
    # take effect through the Python registration below.
    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )

    # Mirror the registration in the Python pytree (the module-level ``_pytree``
    # import) so both implementations agree on which types are nodes.
    _pytree._private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
164

165

166
@deprecated(
    "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. "
    "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.",
    category=FutureWarning,
)
def _register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """Register a container-like type as pytree node for the C++ pytree only.

    .. warning::
        This function is deprecated and emits a :class:`FutureWarning` when called.
        Use :func:`register_pytree_node` instead, which also registers the type with
        the Python pytree in :mod:`torch.utils._pytree`.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_fn (callable): A function to be used during flattening, taking an instance of
            ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
            recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
            passed to the ``unflatten_fn``.
        unflatten_fn (callable): A function taking two arguments: the unflattened children, and
            the auxiliary data that was returned by ``flatten_fn`` and stored in the treespec.
            The function should return an instance of ``cls``.
        serialized_type_name (str, optional): A keyword argument used to specify the fully
            qualified name used when serializing the tree spec.
        to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable representation. This is
            used for json serialization, which is being used in :mod:`torch.export` right now.
        from_dumpable_context (callable, optional): An optional keyword argument to custom specify
            how to convert the custom json dumpable representation of the context back to the
            original context. This is used for json deserialization, which is being used in
            :mod:`torch.export` right now.
    """

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
220

221

222
def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
) -> None:
    """Internal helper: register ``cls`` as a node in the C++ (optree) pytree registry only.

    End-users should call :func:`register_pytree_node` instead. The serialization
    keyword arguments are accepted for signature compatibility with the Python
    pytree helper but are not used here. PyStructSequence classes are skipped.
    """
    # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
    # PyStructSequence types
    if optree.is_structseq_class(cls):
        return

    # optree calls unflatten functions as (context, children); ours take
    # (children, context), hence the argument-reversing adapter.
    optree.register_pytree_node(
        cls,
        flatten_fn,
        _reverse_args(unflatten_fn),
        namespace="torch",
    )
244

245

246
def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
    """Flatten a pytree into a list of leaves plus a treespec describing its structure.

    Use :func:`tree_unflatten` to invert this operation.

    Leaves are emitted in a deterministic order given by a left-to-right depth-first
    traversal. ``None`` is treated as a leaf rather than as an empty node.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)
    ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
    >>> tree_flatten(1)
    ([1], PyTreeSpec(*, NoneIsLeaf))
    >>> tree_flatten(None)
    ([None], PyTreeSpec(*, NoneIsLeaf))

    Plain :class:`dict` and :class:`collections.defaultdict` nodes are traversed in
    **sorted** key order. Use :class:`collections.OrderedDict` to keep insertion order:

    >>> from collections import OrderedDict
    >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
    >>> tree_flatten(tree)
    ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))

    Args:
        tree (pytree): The pytree to flatten.
        is_leaf (callable, optional): Extra predicate ``is_leaf(node) -> bool`` evaluated at
            every flattening step. When it returns :data:`True`, the whole subtree rooted at
            ``node`` is treated as a single leaf; otherwise the pytree registry decides. If
            omitted, only the registry is consulted.

    Returns:
        A pair ``(leaves, treespec)``: the list of leaf values and the treespec describing
        the structure of ``tree``.
    """
    flat = optree.tree_flatten(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
    return flat  # type: ignore[return-value]
292

293

294
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Rebuild a pytree by placing ``leaves`` into the structure described by ``treespec``.

    This is the inverse of :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> leaves, treespec = tree_flatten(tree)
    >>> tree == tree_unflatten(leaves, treespec)
    True

    Args:
        leaves (iterable): The leaf values, in flattening order. Their count must equal
            the number of leaves in ``treespec``.
        treespec (TreeSpec): The structure to rebuild.

    Returns:
        The reconstructed pytree.

    Raises:
        TypeError: If ``treespec`` is not a :data:`TreeSpec`.
    """
    if isinstance(treespec, TreeSpec):
        return optree.tree_unflatten(treespec, leaves)  # type: ignore[arg-type]
    raise TypeError(
        f"tree_unflatten(values, spec): Expected `spec` to be instance of "
        f"TreeSpec but got item of type {type(treespec)}."
    )
319

320

321
def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
    """Return an iterator over the leaves of a pytree, in :func:`tree_flatten` order.

    ``None`` counts as a leaf.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> list(tree_iter(tree))
    [1, 2, 3, 4, None, 5]
    >>> list(tree_iter(1))
    [1]
    >>> list(tree_iter(None))
    [None]

    Args:
        tree (pytree): The pytree whose leaves are iterated.
        is_leaf (callable, optional): Extra predicate ``is_leaf(node) -> bool`` evaluated at
            every flattening step. When it returns :data:`True`, the whole subtree rooted at
            ``node`` is treated as a single leaf; otherwise the pytree registry decides. If
            omitted, only the registry is consulted.

    Returns:
        An iterator over the leaf values.
    """
    leaf_iter = optree.tree_iter(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
    return leaf_iter
354

355

356
def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Return the leaves of a pytree as a list, in :func:`tree_flatten` order.

    ``None`` counts as a leaf.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_leaves(tree)
    [1, 2, 3, 4, None, 5]
    >>> tree_leaves(1)
    [1]
    >>> tree_leaves(None)
    [None]

    Args:
        tree (pytree): The pytree whose leaves are collected.
        is_leaf (callable, optional): Extra predicate ``is_leaf(node) -> bool`` evaluated at
            every flattening step. When it returns :data:`True`, the whole subtree rooted at
            ``node`` is treated as a single leaf; otherwise the pytree registry decides. If
            omitted, only the registry is consulted.

    Returns:
        A list of leaf values.
    """
    leaves = optree.tree_leaves(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
    return leaves
389

390

391
def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Return only the treespec of a pytree, i.e. the second item of :func:`tree_flatten`.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_structure(tree)
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
    >>> tree_structure(1)
    PyTreeSpec(*, NoneIsLeaf)
    >>> tree_structure(None)
    PyTreeSpec(*, NoneIsLeaf)

    Args:
        tree (pytree): The pytree whose structure is computed.
        is_leaf (callable, optional): Extra predicate ``is_leaf(node) -> bool`` evaluated at
            every flattening step. When it returns :data:`True`, the whole subtree rooted at
            ``node`` is treated as a single leaf; otherwise the pytree registry decides. If
            omitted, only the registry is consulted.

    Returns:
        A treespec describing the structure of ``tree``.
    """
    spec = optree.tree_structure(
        tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
    return spec  # type: ignore[return-value]
424

425

426
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Apply ``func`` leaf-wise across one or more pytrees and build a new pytree.

    See also :func:`tree_map_` for the in-place (side-effect) variant.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}

    With several inputs, the output structure follows ``tree``. Every tree in
    ``rests`` only needs to have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): Called with ``1 + len(rests)`` positional arguments at each leaf
            position.
        tree (pytree): Supplies the first argument to ``func`` at each leaf, and the output
            structure.
        rests (tuple of pytree): Further pytrees, each having the same structure as ``tree``
            or having ``tree`` as a prefix.
        is_leaf (callable, optional): Extra predicate ``is_leaf(node) -> bool`` evaluated at
            every flattening step. When it returns :data:`True`, the whole subtree rooted at
            ``node`` is treated as a single leaf; otherwise the pytree registry decides. If
            omitted, only the registry is consulted.

    Returns:
        A new pytree shaped like ``tree`` whose leaves are ``func(x, *xs)``. Here ``x`` is
        the leaf of ``tree`` and ``xs`` holds the values at the matching positions in
        ``rests``.
    """
    mapped = optree.tree_map(
        func,
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
    return mapped
473

474

475
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but call ``func`` only for its side effects and return the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether
            a node is a leaf. If the function is not specified, the default pytree registry will
            be used.

    Returns:
        The original ``tree``. ``func(x, *xs)`` has been called at each leaf for its side
        effects, and its return value is discarded. Here ``x`` is the value at the
        corresponding leaf in ``tree`` and ``xs`` is the tuple of values at corresponding
        nodes in ``rests``.
    """
    return optree.tree_map_(
        func,
        tree,
        *rests,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
511

512

513
# Pairs / triples of classes accepted by the typed ``*_only`` overloads.
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
# Any class filter accepted by the ``*_only`` helpers (passed to ``isinstance``).
# ``X | Y`` union objects (``types.UnionType``) exist only on Python 3.10+.
if sys.version_info < (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]

# Callables over values of the matching class(es), returning ``R``.
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# Type of the decorator returned by ``map_only(...)``.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
526

527

528
# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This specialization is needed for the implementations below that call map_only()
# with a generic ``TypeAny`` argument (e.g. the ``tree_map_only`` implementations).
@overload
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
    """Build a decorator that applies a function only to values matching a filter.

    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged.  Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'

    Args:
        __type_or_types_or_pred: The filter. A class, a tuple of classes, or (Python
            3.10+) an ``X | Y`` union is used with ``isinstance``. Any other callable is
            used as a predicate ``pred(x) -> bool``.

    Returns:
        A decorator. The decorated function calls ``func(x)`` when ``x`` matches the
        filter and otherwise returns ``x`` unchanged.

    Raises:
        TypeError: If the filter is neither a class/tuple/union nor a callable.
    """
    # Check the class-like forms first: classes are also callable, but must be
    # treated as isinstance filters rather than predicates.
    if isinstance(__type_or_types_or_pred, (type, tuple)) or (
        sys.version_info >= (3, 10)
        and isinstance(__type_or_types_or_pred, types.UnionType)
    ):

        def pred(x: Any) -> bool:
            return isinstance(x, __type_or_types_or_pred)  # type: ignore[arg-type]

    elif callable(__type_or_types_or_pred):
        pred = __type_or_types_or_pred  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(func)
        def wrapped(x: T) -> Any:
            if pred(x):
                return func(x)
            return x

        return wrapped

    return wrapper
600

601

602
@overload
def tree_map_only(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but apply ``func`` only to leaves that match a filter.

    Non-matching leaves are copied into the result unchanged. The filter takes the
    same forms accepted by :func:`map_only`.
    """
    selective_fn = map_only(__type_or_types_or_pred)(func)
    return tree_map(selective_fn, tree, is_leaf=is_leaf)
649

650

651
@overload
def tree_map_only_(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only_(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map_`, but call ``func`` only on leaves that match a filter.

    The filter takes the same forms accepted by :func:`map_only`. The original
    ``tree`` is returned.
    """
    selective_fn = map_only(__type_or_types_or_pred)(func)
    return tree_map_(selective_fn, tree, is_leaf=is_leaf)
698

699

700
def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf of ``tree``.

    Evaluation stops at the first falsy result. A tree with no leaves yields ``True``.
    """
    return all(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
707

708

709
def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for at least one leaf of ``tree``.

    Evaluation stops at the first truthy result. A tree with no leaves yields ``False``.
    """
    return any(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
716

717

718
@overload
def tree_all_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_all_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf that is an instance of ``__type_or_types``.

    Leaves of other types are ignored. Returns ``True`` when no leaf matches.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, __type_or_types) and not pred(leaf):
            return False
    return True
756

757

758
@overload
def tree_any_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_any_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for some leaf that is an instance of ``__type_or_types``.

    Leaves of other types are ignored. Returns ``False`` when no leaf matches.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, __type_or_types) and pred(leaf):
            return True
    return False
796

797

798
def broadcast_prefix(
    prefix_tree: PyTree,
    full_tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.

    If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
    constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.

    This function returns a list of leaves with the same size as ``full_tree``. The leaves are
    replicated from ``prefix_tree``. The number of replicas is determined by the corresponding
    subtree in ``full_tree``.

    >>> broadcast_prefix(1, [1, 2, 3])
    [1, 1, 1]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3])
    [1, 2, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
    Traceback (most recent call last):
        ...
    ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
    >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
    [1, 2, 3, 3]
    >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
    [1, 2, 3, 3, 3, 3]

    Args:
        prefix_tree (pytree): A pytree whose structure is a prefix of the structure of
            ``full_tree``.
        full_tree (pytree): A pytree whose structure has ``prefix_tree`` as a prefix, i.e. it
            can be obtained by replacing leaves of ``prefix_tree`` with subtrees.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether
            a node is a leaf. If the function is not specified, the default pytree registry will
            be used.

    Returns:
        A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.

    Raises:
        ValueError: If ``prefix_tree`` is not a prefix of ``full_tree`` (see the arity
            mismatch example above).
    """
    return optree.broadcast_prefix(
        prefix_tree,
        full_tree,
        is_leaf=is_leaf,
        none_is_leaf=True,
        namespace="torch",
    )
844

845

846
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
    """Broadcast ``tree`` to the structure ``treespec`` and return the flattened values.

    Returns ``None`` instead of raising when ``tree`` cannot be broadcast to
    ``treespec``. For example, ``tree=0`` broadcast to a two-element list
    structure yields ``[0, 0]``.

    This is used by the vmap implementation: in ``vmap(fn, in_dims)(*inputs)``,
    ``in_dims`` must be broadcastable to the tree structure of ``inputs``, and this
    helper performs that check.
    """
    assert isinstance(treespec, TreeSpec)
    # A placeholder tree with the target structure; broadcast_prefix only needs its
    # shape, not its leaf values.
    template = tree_unflatten([0] * treespec.num_leaves, treespec)
    try:
        broadcasted = broadcast_prefix(tree, template, is_leaf=is_leaf)
    except ValueError:
        return None
    return broadcasted
865

866

867
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    """Serialize a treespec to a JSON string.

    The C++ (optree) treespec is first rebuilt as the equivalent treespec of the
    Python implementation in ``torch.utils._pytree``; that module's serializer
    then produces the string.

    Args:
        treespec: The treespec to serialize.
        protocol: Optional serialization protocol, forwarded unchanged to the
            Python pytree serializer.

    Returns:
        The serialized treespec.

    Raises:
        TypeError: If ``treespec`` is not a :class:`TreeSpec`.
    """
    if isinstance(treespec, TreeSpec):
        from ._pytree import (
            tree_structure as _py_tree_structure,
            treespec_dumps as _py_treespec_dumps,
        )

        # Round-trip through a placeholder tree so the Python implementation
        # can recover the same structure.
        placeholder_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
        python_treespec = _py_tree_structure(placeholder_tree)
        return _py_treespec_dumps(python_treespec, protocol=protocol)

    raise TypeError(
        f"treespec_dumps(spec): Expected `spec` to be instance of "
        f"TreeSpec but got item of type {type(treespec)}."
    )
881

882

883
def treespec_loads(serialized: str) -> TreeSpec:
    """Deserialize a treespec from a JSON string.

    The string is parsed by the Python implementation in ``torch.utils._pytree``.
    The resulting structure is then rebuilt as a C++ (optree) treespec via a
    placeholder tree.

    Args:
        serialized: The serialized treespec (e.g. the output of :func:`treespec_dumps`).

    Returns:
        The deserialized :class:`TreeSpec`.
    """
    from ._pytree import (
        tree_unflatten as _py_tree_unflatten,
        treespec_loads as _py_treespec_loads,
    )

    python_treespec = _py_treespec_loads(serialized)
    placeholder_leaves = [0] * python_treespec.num_leaves
    return tree_structure(_py_tree_unflatten(placeholder_leaves, python_treespec))
894

895

896
class _DummyLeaf:
897
    def __repr__(self) -> str:
898
        return "*"
899

900

901
def treespec_pprint(treespec: TreeSpec) -> str:
    """Return a readable string describing the structure of ``treespec``.

    The result is the ``repr`` of a tree rebuilt from ``treespec`` in which
    every leaf is a placeholder printed as ``*``.
    """
    placeholders = [_DummyLeaf() for _ in range(treespec.num_leaves)]
    return repr(tree_unflatten(placeholders, treespec))
907

908

909
class LeafSpecMeta(type(TreeSpec)):  # type: ignore[misc]
    """Metaclass making ``isinstance(x, LeafSpec)`` true for every leaf treespec.

    ``LeafSpec()`` returns a plain optree leaf treespec rather than a real
    ``LeafSpec`` object, so the instance check is structural: any
    :class:`TreeSpec` whose ``is_leaf()`` is true counts.
    """

    def __instancecheck__(self, instance: object) -> bool:
        if not isinstance(instance, TreeSpec):
            return False
        return instance.is_leaf()
912

913

914
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
    """Treespec describing a single leaf.

    Calling ``LeafSpec()`` returns the optree leaf treespec created with
    ``none_is_leaf=True``, not an instance of this class. ``isinstance`` checks
    against this class are customized by :class:`LeafSpecMeta`.
    """

    def __new__(cls) -> "LeafSpec":
        leaf_spec = optree.treespec_leaf(none_is_leaf=True)
        return leaf_spec  # type: ignore[return-value]
917

918

919
def tree_flatten_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
    """Flatten a pytree like :func:`tree_flatten`, also reporting each leaf's key path.

    Note:
        Key paths are not supported by this C++-backed module yet, so this
        function currently always raises :class:`NotImplementedError`.

    Args:
        tree: The pytree to flatten. Any custom type inside it must have been
            registered via :func:`register_pytree_node` with an appropriate
            ``tree_flatten_with_path_fn``.
        is_leaf: Optional extra leaf predicate, called at every flattening step
            as ``is_leaf(node) -> bool``. When it returns :data:`True`, the whole
            subtree is treated as a single leaf; otherwise the default pytree
            registry decides whether ``node`` is a leaf. If omitted, only the
            default registry is used.

    Returns:
        A ``(pairs, treespec)`` tuple, where ``pairs`` is a list of
        ``(key_path, leaf)`` tuples and ``treespec`` is the :class:`TreeSpec`
        describing the structure of the flattened tree.

    Raises:
        NotImplementedError: Always, until key paths are supported in cxx_pytree.
    """
    message = "KeyPaths are not yet supported in cxx_pytree."
    raise NotImplementedError(message)
940

941

942
def tree_leaves_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Tuple[KeyPath, Any]]:
    """Collect the leaves of a pytree like ``tree_leaves``, each paired with its key path.

    Note:
        Key paths are not supported by this C++-backed module yet, so this
        function currently always raises :class:`NotImplementedError`.

    Args:
        tree: The pytree to traverse. Any custom type inside it must have been
            registered via :func:`register_pytree_node` with an appropriate
            ``tree_flatten_with_path_fn``.
        is_leaf: Optional extra leaf predicate, called at every flattening step
            as ``is_leaf(node) -> bool``. When it returns :data:`True`, the whole
            subtree is treated as a single leaf; otherwise the default pytree
            registry decides whether ``node`` is a leaf. If omitted, only the
            default registry is used.

    Returns:
        A list of ``(key_path, leaf)`` tuples.

    Raises:
        NotImplementedError: Always, until key paths are supported in cxx_pytree.
    """
    message = "KeyPaths are not yet supported in cxx_pytree."
    raise NotImplementedError(message)
961

962

963
def tree_map_with_path(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.

    Note:
        Key paths are not supported by this C++-backed module yet, so this
        function currently always raises :class:`NotImplementedError`.

    Args:
        func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. The first positional argument
            to ``func`` is the key path of the leaf in question. The second
            positional argument is the value of the leaf.
        tree: A pytree to be mapped over. Each of its leaves provides the second
            positional argument to ``func``; the first is that leaf's key path.
        rests: A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is
            treated as a leaf. Otherwise, the default pytree registry is used to determine
            whether a node is a leaf. If the function is not specified, the default pytree
            registry will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
        corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
        ``xs`` is the tuple of values at corresponding nodes in ``rests``.

    Raises:
        NotImplementedError: Always, until key paths are supported in cxx_pytree.
    """
    raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
993

994

995
def keystr(kp: KeyPath) -> str:
    """Given a key path, return a pretty-printed representation.

    Key path entries are the ``KeyEntry`` types defined in
    ``torch.utils._pytree``. Formatting is therefore delegated to that
    module's ``keystr`` rather than raising ``NotImplementedError``.

    Args:
        kp: The key path to format.

    Returns:
        The pretty-printed key path.
    """
    return _pytree.keystr(kp)
998

999

1000
def key_get(obj: Any, kp: KeyPath) -> Any:
    """Given an object and a key path, return the value at the key path.

    Key path entries are the ``KeyEntry`` types defined in
    ``torch.utils._pytree``. The lookup is therefore delegated to that
    module's ``key_get`` rather than raising ``NotImplementedError``.

    Args:
        obj: The root object to index into.
        kp: The key path to follow from ``obj``.

    Returns:
        The value found at ``kp`` inside ``obj``.
    """
    return _pytree.key_get(obj, kp)
1003

1004

1005
# Mark this C++-backed module as imported in ``torch.utils._pytree``, then replay
# every registration queued in ``_pytree._cxx_pytree_pending_imports`` through
# ``_private_register_pytree_node``.
# NOTE(review): the queue presumably holds registrations made before this module
# was imported. The flag is set before the replay loop, presumably so that later
# registrations bypass the queue; confirm against
# ``torch.utils._pytree.register_pytree_node``.
# If the queue is non-empty, the loop variables ``args``/``kwargs`` remain bound
# as module globals afterwards.
_pytree._cxx_pytree_imported = True
for args, kwargs in _pytree._cxx_pytree_pending_imports:
    _private_register_pytree_node(*args, **kwargs)
1008

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.