2
Contains utility functions for working with nested python data structures.
4
A *pytree* is a Python nested data structure. It is a tree in the sense that
5
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
6
Python values. Furthermore, a pytree should not contain reference cycles.
8
pytrees are useful for working with nested collections of Tensors. For example,
9
one can use `tree_map` to map a function over all Tensors inside some nested
10
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
11
inside some nested collection. pytrees are helpful for implementing nested
12
collection support for PyTorch APIs.
30
from typing_extensions import deprecated
33
from optree import PyTreeSpec # direct import for type annotations
35
import torch.utils._pytree as _pytree
36
from torch.utils._pytree import KeyEntry
45
"ToDumpableContextFn",
46
"FromDumpableContextFn",
51
"register_pytree_node",
53
"tree_flatten_with_path",
57
"tree_leaves_with_path",
83
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
84
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
85
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
86
DumpableContext = Any # Any json dumpable text
87
ToDumpableContextFn = Callable[[Context], DumpableContext]
88
FromDumpableContextFn = Callable[[DumpableContext], Context]
89
KeyPath = Tuple[KeyEntry, ...]
90
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
93
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
94
@functools.wraps(func)
95
def wrapped(*args: Any, **kwargs: Any) -> Any:
96
return func(*reversed(args), **kwargs)
101
def register_pytree_node(
103
flatten_fn: FlattenFunc,
104
unflatten_fn: UnflattenFunc,
106
serialized_type_name: Optional[str] = None,
107
to_dumpable_context: Optional[ToDumpableContextFn] = None,
108
from_dumpable_context: Optional[FromDumpableContextFn] = None,
109
flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
111
"""Register a container-like type as pytree node.
114
cls (type): A Python type to treat as an internal pytree node.
115
flatten_fn (callable): A function to be used during flattening, taking an instance of
116
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
117
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
118
passed to the ``unflatten_fn``.
119
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
120
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
121
The function should return an instance of ``cls``.
122
serialized_type_name (str, optional): A keyword argument used to specify the fully
123
qualified name used when serializing the tree spec.
124
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
125
to convert the context of the pytree to a custom json dumpable representation. This is
126
used for json serialization, which is being used in :mod:`torch.export` right now.
127
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
128
how to convert the custom json dumpable representation of the context back to the
129
original context. This is used for json deserialization, which is being used in
130
:mod:`torch.export` right now.
134
>>> # xdoctest: +SKIP
135
>>> # Register a Python type with lambda functions
136
>>> register_pytree_node(
138
... lambda s: (sorted(s), None, None),
139
... lambda children, _: set(children),
142
if flatten_with_keys_fn is not None:
143
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
145
_private_register_pytree_node(
149
serialized_type_name=serialized_type_name,
150
to_dumpable_context=to_dumpable_context,
151
from_dumpable_context=from_dumpable_context,
154
from . import _pytree as python
156
python._private_register_pytree_node(
160
serialized_type_name=serialized_type_name,
161
to_dumpable_context=to_dumpable_context,
162
from_dumpable_context=from_dumpable_context,
167
"`torch.utils._cxx_pytree._register_pytree_node` is deprecated. "
168
"Please use `torch.utils._cxx_pytree.register_pytree_node` instead.",
169
category=FutureWarning,
171
def _register_pytree_node(
173
flatten_fn: FlattenFunc,
174
unflatten_fn: UnflattenFunc,
176
serialized_type_name: Optional[str] = None,
177
to_dumpable_context: Optional[ToDumpableContextFn] = None,
178
from_dumpable_context: Optional[FromDumpableContextFn] = None,
180
"""Register a container-like type as pytree node for the C++ pytree only.
182
The ``namespace`` argument is used to avoid collisions that occur when different libraries
183
register the same Python type with different behaviors. It is recommended to add a unique prefix
184
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
185
the same class in different namespaces for different use cases.
188
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
189
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
190
prevent accidental collisions between different libraries that may register the same type.
193
cls (type): A Python type to treat as an internal pytree node.
194
flatten_fn (callable): A function to be used during flattening, taking an instance of
195
``cls`` and returning a pair, with (1) an iterable for the children to be flattened
196
recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
197
passed to the ``unflatten_fn``.
198
unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
199
returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
200
The function should return an instance of ``cls``.
201
serialized_type_name (str, optional): A keyword argument used to specify the fully
202
qualified name used when serializing the tree spec.
203
to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
204
to convert the context of the pytree to a custom json dumpable representation. This is
205
used for json serialization, which is being used in :mod:`torch.export` right now.
206
from_dumpable_context (callable, optional): An optional keyword argument to custom specify
207
how to convert the custom json dumpable representation of the context back to the
208
original context. This is used for json deserialization, which is being used in
209
:mod:`torch.export` right now.
212
_private_register_pytree_node(
216
serialized_type_name=serialized_type_name,
217
to_dumpable_context=to_dumpable_context,
218
from_dumpable_context=from_dumpable_context,
222
def _private_register_pytree_node(
224
flatten_fn: FlattenFunc,
225
unflatten_fn: UnflattenFunc,
227
serialized_type_name: Optional[str] = None,
228
to_dumpable_context: Optional[ToDumpableContextFn] = None,
229
from_dumpable_context: Optional[FromDumpableContextFn] = None,
231
"""This is an internal function that is used to register a pytree node type
232
for the C++ pytree only. End-users should use :func:`register_pytree_node`
235
# TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
236
# PyStructSequence types
237
if not optree.is_structseq_class(cls):
238
optree.register_pytree_node(
241
_reverse_args(unflatten_fn),
248
is_leaf: Optional[Callable[[PyTree], bool]] = None,
249
) -> Tuple[List[Any], TreeSpec]:
252
See also :func:`tree_unflatten`.
254
The flattening order (i.e., the order of elements in the output list) is deterministic,
255
corresponding to a left-to-right depth-first tree traversal.
257
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
258
>>> tree_flatten(tree)
259
([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
261
([1], PyTreeSpec(*, NoneIsLeaf))
262
>>> tree_flatten(None)
263
([None], PyTreeSpec(*, NoneIsLeaf))
265
For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
266
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
267
if you want to keep the keys in the insertion order.
269
>>> from collections import OrderedDict
270
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
271
>>> tree_flatten(tree)
272
([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf))
275
tree (pytree): A pytree to flatten.
276
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
277
flattening step. The function should have a single argument with signature
278
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
279
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
280
leaf or not. If the function is not specified, the default pytree registry will be used.
283
A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
284
second element is a treespec representing the structure of the pytree.
286
return optree.tree_flatten( # type: ignore[return-value]
294
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
295
"""Reconstruct a pytree from the treespec and the leaves.
297
The inverse of :func:`tree_flatten`.
299
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
300
>>> leaves, treespec = tree_flatten(tree)
301
>>> tree == tree_unflatten(leaves, treespec)
305
leaves (iterable): The list of leaves to use for reconstruction. The list must match the
306
number of leaves of the treespec.
307
treespec (TreeSpec): The treespec to reconstruct.
310
The reconstructed pytree, containing the ``leaves`` placed in the structure described by
313
if not isinstance(treespec, TreeSpec):
315
f"tree_unflatten(values, spec): Expected `spec` to be instance of "
316
f"TreeSpec but got item of type {type(treespec)}."
318
return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
323
is_leaf: Optional[Callable[[PyTree], bool]] = None,
325
"""Get an iterator over the leaves of a pytree.
327
See also :func:`tree_flatten`.
329
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
330
>>> list(tree_iter(tree))
331
[1, 2, 3, 4, None, 5]
332
>>> list(tree_iter(1))
334
>>> list(tree_iter(None))
338
tree (pytree): A pytree to flatten.
339
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
340
flattening step. The function should have a single argument with signature
341
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
342
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
343
leaf or not. If the function is not specified, the default pytree registry will be used.
346
An iterator over the leaf values.
348
return optree.tree_iter(
358
is_leaf: Optional[Callable[[PyTree], bool]] = None,
360
"""Get the leaves of a pytree.
362
See also :func:`tree_flatten`.
364
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
365
>>> tree_leaves(tree)
366
[1, 2, 3, 4, None, 5]
369
>>> tree_leaves(None)
373
tree (pytree): A pytree to flatten.
374
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
375
flattening step. The function should have a single argument with signature
376
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
377
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
378
leaf or not. If the function is not specified, the default pytree registry will be used.
381
A list of leaf values.
383
return optree.tree_leaves(
393
is_leaf: Optional[Callable[[PyTree], bool]] = None,
395
"""Get the treespec for a pytree.
397
See also :func:`tree_flatten`.
399
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
400
>>> tree_structure(tree)
401
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
402
>>> tree_structure(1)
403
PyTreeSpec(*, NoneIsLeaf)
404
>>> tree_structure(None)
405
PyTreeSpec(*, NoneIsLeaf)
408
tree (pytree): A pytree to flatten.
409
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
410
flattening step. The function should have a single argument with signature
411
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
412
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
413
leaf or not. If the function is not specified, the default pytree registry will be used.
416
A treespec object representing the structure of the pytree.
418
return optree.tree_structure( # type: ignore[return-value]
427
func: Callable[..., Any],
430
is_leaf: Optional[Callable[[PyTree], bool]] = None,
432
"""Map a multi-input function over pytree args to produce a new pytree.
434
See also :func:`tree_map_`.
436
>>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
437
{'x': 8, 'y': (43, 65)}
438
>>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
439
{'x': False, 'y': (False, False), 'z': True}
441
If multiple inputs are given, the structure of the tree is taken from the first input;
442
subsequent inputs need only have ``tree`` as a prefix:
444
>>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
445
[[5, 7, 9], [6, 1, 2]]
448
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
449
corresponding leaves of the pytrees.
450
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
451
argument to function ``func``.
452
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
453
``tree`` or has ``tree`` as a prefix.
454
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
455
flattening step. The function should have a single argument with signature
456
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
457
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
458
leaf or not. If the function is not specified, the default pytree registry will be used.
461
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
462
``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
463
is the tuple of values at corresponding nodes in ``rests``.
465
return optree.tree_map(
476
func: Callable[..., Any],
479
is_leaf: Optional[Callable[[PyTree], bool]] = None,
481
"""Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
483
See also :func:`tree_map`.
486
func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
487
corresponding leaves of the pytrees.
488
tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
489
argument to function ``func``.
490
rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
491
``tree`` or has ``tree`` as a prefix.
492
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
493
flattening step. The function should have a single argument with signature
494
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
495
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
496
leaf or not. If the function is not specified, the default pytree registry will be used.
499
The original ``tree`` with the value at each leaf is given by the side-effect of function
500
``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
501
in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
503
return optree.tree_map_(
513
Type2 = Tuple[Type[T], Type[S]]
514
Type3 = Tuple[Type[T], Type[S], Type[U]]
515
if sys.version_info >= (3, 10):
516
TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
518
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
520
Fn2 = Callable[[Union[T, S]], R]
521
Fn3 = Callable[[Union[T, S, U]], R]
523
FnAny = Callable[[Any], R]
525
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
528
# These specializations help with type inference on the lambda passed to this
531
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
536
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
541
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
545
# This specialization is needed for the implementations below that call
547
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
552
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
557
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
558
) -> MapOnlyFn[FnAny[Any]]:
560
Suppose you are writing a tree_map over tensors, leaving everything
561
else unchanged. Ordinarily you would have to write:
564
if isinstance(t, Tensor):
569
With this function, you only need to write:
575
You can also directly use 'tree_map_only'
577
if isinstance(__type_or_types_or_pred, (type, tuple)) or (
578
sys.version_info >= (3, 10)
579
and isinstance(__type_or_types_or_pred, types.UnionType)
582
def pred(x: Any) -> bool:
583
return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type]
585
elif callable(__type_or_types_or_pred):
586
pred = __type_or_types_or_pred # type: ignore[assignment]
588
raise TypeError("Argument must be a type, a tuple of types, or a callable.")
590
def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
591
@functools.wraps(func)
592
def wrapped(x: T) -> Any:
604
__type_or_types_or_pred: Type[T],
607
is_leaf: Optional[Callable[[PyTree], bool]] = None,
614
__type_or_types_or_pred: Type2[T, S],
615
func: Fn2[T, S, Any],
617
is_leaf: Optional[Callable[[PyTree], bool]] = None,
624
__type_or_types_or_pred: Type3[T, S, U],
625
func: Fn3[T, S, U, Any],
627
is_leaf: Optional[Callable[[PyTree], bool]] = None,
634
__type_or_types_or_pred: Callable[[Any], bool],
637
is_leaf: Optional[Callable[[PyTree], bool]] = None,
643
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
646
is_leaf: Optional[Callable[[PyTree], bool]] = None,
648
return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
653
__type_or_types_or_pred: Type[T],
656
is_leaf: Optional[Callable[[PyTree], bool]] = None,
663
__type_or_types_or_pred: Type2[T, S],
664
func: Fn2[T, S, Any],
666
is_leaf: Optional[Callable[[PyTree], bool]] = None,
673
__type_or_types_or_pred: Type3[T, S, U],
674
func: Fn3[T, S, U, Any],
676
is_leaf: Optional[Callable[[PyTree], bool]] = None,
683
__type_or_types_or_pred: Callable[[Any], bool],
686
is_leaf: Optional[Callable[[PyTree], bool]] = None,
692
__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
695
is_leaf: Optional[Callable[[PyTree], bool]] = None,
697
return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
701
pred: Callable[[Any], bool],
703
is_leaf: Optional[Callable[[PyTree], bool]] = None,
705
flat_args = tree_iter(tree, is_leaf=is_leaf)
706
return all(map(pred, flat_args))
710
pred: Callable[[Any], bool],
712
is_leaf: Optional[Callable[[PyTree], bool]] = None,
714
flat_args = tree_iter(tree, is_leaf=is_leaf)
715
return any(map(pred, flat_args))
720
__type_or_types: Type[T],
723
is_leaf: Optional[Callable[[PyTree], bool]] = None,
730
__type_or_types: Type2[T, S],
731
pred: Fn2[T, S, bool],
733
is_leaf: Optional[Callable[[PyTree], bool]] = None,
740
__type_or_types: Type3[T, S, U],
741
pred: Fn3[T, S, U, bool],
743
is_leaf: Optional[Callable[[PyTree], bool]] = None,
749
__type_or_types: TypeAny,
752
is_leaf: Optional[Callable[[PyTree], bool]] = None,
754
flat_args = tree_iter(tree, is_leaf=is_leaf)
755
return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
760
__type_or_types: Type[T],
763
is_leaf: Optional[Callable[[PyTree], bool]] = None,
770
__type_or_types: Type2[T, S],
771
pred: Fn2[T, S, bool],
773
is_leaf: Optional[Callable[[PyTree], bool]] = None,
780
__type_or_types: Type3[T, S, U],
781
pred: Fn3[T, S, U, bool],
783
is_leaf: Optional[Callable[[PyTree], bool]] = None,
789
__type_or_types: TypeAny,
792
is_leaf: Optional[Callable[[PyTree], bool]] = None,
794
flat_args = tree_iter(tree, is_leaf=is_leaf)
795
return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
801
is_leaf: Optional[Callable[[PyTree], bool]] = None,
803
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
805
If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
806
constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.
808
This function returns a list of leaves with the same size as ``full_tree``. The leaves are
809
replicated from ``prefix_tree``. The number of replicas is determined by the corresponding
810
subtree in ``full_tree``.
812
>>> broadcast_prefix(1, [1, 2, 3])
814
>>> broadcast_prefix([1, 2, 3], [1, 2, 3])
816
>>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
817
Traceback (most recent call last):
819
ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
820
>>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
822
>>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
826
prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
827
full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
828
is_leaf (callable, optional): An extra leaf predicate function that will be called at each
829
flattening step. The function should have a single argument with signature
830
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
831
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
832
leaf or not. If the function is not specified, the default pytree registry will be used.
835
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
837
return optree.broadcast_prefix(
846
# Broadcasts a pytree to the provided TreeSpec and returns the flattened
847
# values. If this is not possible, then this function returns None.
849
# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
850
# would return [0, 0]. This is useful for part of the vmap implementation:
851
# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
852
# broadcastable to the tree structure of `inputs` and we use
853
# _broadcast_to_and_flatten to check this.
854
def _broadcast_to_and_flatten(
857
is_leaf: Optional[Callable[[PyTree], bool]] = None,
858
) -> Optional[List[Any]]:
859
assert isinstance(treespec, TreeSpec)
860
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
862
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
867
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
868
"""Serialize a treespec to a JSON string."""
869
if not isinstance(treespec, TreeSpec):
871
f"treespec_dumps(spec): Expected `spec` to be instance of "
872
f"TreeSpec but got item of type {type(treespec)}."
874
from ._pytree import (
875
tree_structure as _tree_structure,
876
treespec_dumps as _treespec_dumps,
879
orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
880
return _treespec_dumps(orig_treespec, protocol=protocol)
883
def treespec_loads(serialized: str) -> TreeSpec:
884
"""Deserialize a treespec from a JSON string."""
885
from ._pytree import (
886
tree_unflatten as _tree_unflatten,
887
treespec_loads as _treespec_loads,
890
orig_treespec = _treespec_loads(serialized)
891
dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
892
treespec = tree_structure(dummy_tree)
897
def __repr__(self) -> str:
901
def treespec_pprint(treespec: TreeSpec) -> str:
902
dummy_tree = tree_unflatten(
903
[_DummyLeaf() for _ in range(treespec.num_leaves)],
906
return repr(dummy_tree)
909
class LeafSpecMeta(type(TreeSpec)):  # type: ignore[misc]
    """Metaclass that customizes ``isinstance`` checks against its classes.

    An object passes the check when it is a ``TreeSpec`` whose ``is_leaf()``
    call reports a leaf; the concrete class of the object is not consulted.
    """

    def __instancecheck__(self, instance: object) -> bool:
        # Anything that is not a treespec can never be a leaf spec.
        if not isinstance(instance, TreeSpec):
            return False
        return instance.is_leaf()
914
class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
    """Constructor-style handle for a leaf treespec.

    Calling ``LeafSpec()`` does not allocate a new object of this class; it
    hands back whatever ``optree.treespec_leaf(none_is_leaf=True)`` produces.
    ``isinstance`` checks against this class are answered by ``LeafSpecMeta``.
    """

    def __new__(cls) -> "LeafSpec":
        leaf_treespec = optree.treespec_leaf(none_is_leaf=True)
        return leaf_treespec  # type: ignore[return-value]
919
def tree_flatten_with_path(
921
is_leaf: Optional[Callable[[PyTree], bool]] = None,
922
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
923
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
926
tree: a pytree to flatten. If it contains a custom type, that type must be
927
registered with an appropriate `flatten_with_keys_fn` when registered
928
with :func:`register_pytree_node`.
929
is_leaf: An extra leaf predicate function that will be called at each
930
flattening step. The function should have a single argument with signature
931
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
932
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
933
leaf or not. If the function is not specified, the default pytree registry will be used.
935
A tuple where the first element is a list of (key path, leaf) pairs, and the
936
second element is a :class:`TreeSpec` representing the structure of the flattened
939
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
942
def tree_leaves_with_path(
944
is_leaf: Optional[Callable[[PyTree], bool]] = None,
945
) -> List[Tuple[KeyPath, Any]]:
946
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
949
tree: a pytree. If it contains a custom type, that type must be
950
registered with an appropriate `flatten_with_keys_fn` when registered
951
with :func:`register_pytree_node`.
952
is_leaf: An extra leaf predicate function that will be called at each
953
flattening step. The function should have a single argument with signature
954
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
955
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
956
leaf or not. If the function is not specified, the default pytree registry will be used.
958
A list of (key path, leaf) pairs.
960
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
963
def tree_map_with_path(
964
func: Callable[..., Any],
967
is_leaf: Optional[Callable[[PyTree], bool]] = None,
969
"""Like :func:`tree_map`, but the provided callable takes an additional key path argument.
972
func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
973
corresponding leaves of the pytrees. The first positional argument
974
to ``func`` is the key path of the leaf in question. The second
975
positional argument is the value of the leaf.
976
tree: A pytree to be mapped over, with each leaf providing the second positional
977
argument to function ``func``.
978
rests: A tuple of pytrees, each of which has the same structure as
979
``tree`` or has ``tree`` as a prefix.
980
is_leaf: An extra leaf predicate function that will be called at each
981
flattening step. The function should have a single argument with signature
982
``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
983
as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
984
leaf or not. If the function is not specified, the default pytree registry will be used.
987
A new pytree with the same structure as ``tree`` but with the value at each leaf given by
988
``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
989
corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
990
``xs`` is the tuple of values at corresponding nodes in ``rests``.
992
raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
995
def keystr(kp: KeyPath) -> str:
    """Given a key path, return a pretty-printed representation.

    Key paths are not supported by this implementation yet, so the call
    always fails.

    Args:
        kp: The key path to render. It is never inspected.

    Raises:
        NotImplementedError: Always.
    """
    message = "KeyPaths are not yet supported in cxx_pytree."
    raise NotImplementedError(message)
1000
def key_get(obj: Any, kp: KeyPath) -> Any:
    """Given an object and a key path, return the value at the key path.

    Key paths are not supported by this implementation yet, so the call
    always fails.

    Args:
        obj: The object that would be indexed. It is never inspected.
        kp: The key path to follow. It is never inspected.

    Raises:
        NotImplementedError: Always.
    """
    message = "KeyPaths are not yet supported in cxx_pytree."
    raise NotImplementedError(message)
1005
# Import-time side effect: flag on the Python pytree module that this module
# has been imported, then replay every registration queued in
# ``_pytree._cxx_pytree_pending_imports`` through the private registration
# helper of this module.
# NOTE(review): presumably the queue is filled by torch.utils._pytree while
# this module has not been imported yet -- confirm against that module.
_pytree._cxx_pytree_imported = True
for args, kwargs in _pytree._cxx_pytree_pending_imports:
    # Each queued entry is an ``(args, kwargs)`` pair forwarded unchanged.
    _private_register_pytree_node(*args, **kwargs)